diff --git a/qupulse/_program/__init__.py b/qupulse/_program/__init__.py index 93773ebb..f142e957 100644 --- a/qupulse/_program/__init__.py +++ b/qupulse/_program/__init__.py @@ -1 +1,79 @@ """This is a private package meaning there are no stability guarantees.""" +from abc import ABC, abstractmethod +from typing import Optional, Union, Sequence, ContextManager, Mapping + +import numpy as np + +from qupulse._program.waveforms import Waveform +from qupulse.utils.types import MeasurementWindow, TimeType +from qupulse._program.volatile import VolatileRepetitionCount + +try: + import qupulse_rs +except ImportError: + qupulse_rs = None + RsProgramBuilder = None +else: + from qupulse_rs.replacements import ProgramBuilder as RsProgramBuilder + +try: + from typing import Protocol, runtime_checkable +except ImportError: + Protocol = object + + def runtime_checkable(cls): + return cls + + +RepetitionCount = Union[int, VolatileRepetitionCount] + + +@runtime_checkable +class Program(Protocol): + """This protocol is used to inspect and or manipulate programs""" + + def to_single_waveform(self) -> Waveform: + pass + + def get_measurement_windows(self) -> Mapping[str, np.ndarray]: + pass + + @property + def duration(self) -> TimeType: + raise NotImplementedError() + + def make_compatible_inplace(self): + # TODO: rename? + pass + + +class ProgramBuilder(Protocol): + """This protocol is used by PulseTemplate to build the program.""" + + def append_leaf(self, waveform: Waveform, + measurements: Optional[Sequence[MeasurementWindow]] = None, + repetition_count: int = 1): + pass + + def potential_child(self, measurements: Optional[Sequence[MeasurementWindow]], + repetition_count: Union[VolatileRepetitionCount, int] = 1) -> ContextManager['ProgramBuilder']: + """ + + Args: + measurements: Measurements to attach to the potential child. Is not repeated with repetition_count. + repetition_count: + + Returns: + + """ + + def to_program(self) -> Optional[Program]: + pass + + +def default_program_builder() -> ProgramBuilder: + if RsProgramBuilder is None: + from qupulse._program._loop import Loop + return Loop() + else: + return RsProgramBuilder() diff --git a/qupulse/_program/_loop.py b/qupulse/_program/_loop.py index 92d3f7eb..6424d58a 100644 --- a/qupulse/_program/_loop.py +++ b/qupulse/_program/_loop.py @@ -1,4 +1,5 @@ -from typing import Union, Dict, Iterable, Tuple, cast, List, Optional, Generator, Mapping +import contextlib +from typing import Union, Dict, Iterable, Tuple, cast, List, Optional, Generator, Mapping, ContextManager, Sequence from collections import defaultdict from enum import Enum import warnings @@ -15,6 +16,7 @@ from qupulse.utils.tree import Node, is_tree_circular from qupulse.utils.numeric import smallest_factor_ge +from qupulse._program import ProgramBuilder, Program from qupulse._program.waveforms import SequenceWaveform, RepetitionWaveform __all__ = ['Loop', 'make_compatible', 'MakeCompatibleWarning'] @@ -101,6 +103,9 @@ def add_measurements(self, measurements: Iterable[MeasurementWindow]): Args: measurements: Measurements to add """ + warnings.warn("Loop.add_measurements is deprecated since qupulse 0.7 and will be removed in a future version.", + DeprecationWarning, + stacklevel=2) body_duration = float(self.body_duration) if body_duration == 0: measurements = measurements @@ -198,23 +203,47 @@ def encapsulate(self) -> None: self._measurements = None self.assert_tree_integrity() - def _get_repr(self, first_prefix, other_prefixes) -> Generator[str, None, None]: + def __repr__(self): + kwargs = [] + + repetition_count = self._repetition_definition + if repetition_count != 1: + kwargs.append(f"repetition_count={repetition_count!r}") + + waveform = self._waveform + if waveform: + kwargs.append(f"waveform={waveform!r}") + + children = self.children + if children: + try: + kwargs.append(f"children={self._children_repr()}") + except RecursionError: + kwargs.append("children=[...]") + + measurements = self._measurements + if measurements: + kwargs.append(f"measurements={measurements!r}") + + return f"Loop({','.join(kwargs)})" + + def _get_str(self, first_prefix, other_prefixes) -> Generator[str, None, None]: if self.is_leaf(): yield '%sEXEC %r %d times' % (first_prefix, self._waveform, self.repetition_count) else: yield '%sLOOP %d times:' % (first_prefix, self.repetition_count) for elem in self: - yield from cast(Loop, elem)._get_repr(other_prefixes + ' ->', other_prefixes + ' ') + yield from cast(Loop, elem)._get_str(other_prefixes + ' ->', other_prefixes + ' ') - def __repr__(self) -> str: + def __str__(self) -> str: is_circular = is_tree_circular(self) if is_circular: return '{}: Circ {}'.format(id(self), is_circular) str_len = 0 repr_list = [] - for sub_repr in self._get_repr('', ''): + for sub_repr in self._get_str('', ''): str_len += len(sub_repr) if self.MAX_REPR_SIZE and str_len > self.MAX_REPR_SIZE: @@ -404,6 +433,21 @@ def _merge_single_child(self): self._invalidate_duration() return True + @contextlib.contextmanager + def potential_child(self, + measurements: Optional[List[MeasurementWindow]], + repetition_count: Union[VolatileRepetitionCount, int] = 1) -> ContextManager['Loop']: + if repetition_count != 1 and measurements: + # current design requires an extra level of nesting here because the measurements are NOT to be repeated + # with the repetition count + inner_child = Loop(repetition_count=repetition_count) + child = Loop(measurements=measurements, children=[inner_child]) + else: + inner_child = child = Loop(measurements=measurements, repetition_count=repetition_count) + yield inner_child + if inner_child.waveform or len(inner_child): + self.append_child(child) + def cleanup(self, actions=('remove_empty_loops', 'merge_single_child')): """Apply the specified actions to cleanup the Loop. @@ -451,6 +495,32 @@ def get_duration_structure(self) -> Tuple[int, Union[TimeType, tuple]]: else: return self.repetition_count, tuple(child.get_duration_structure() for child in self) + def to_single_waveform(self) -> Waveform: + if self.is_leaf(): + if self.repetition_count == 1: + return self.waveform + else: + return RepetitionWaveform.from_repetition_count(self.waveform, self.repetition_count) + else: + if len(self) == 1: + sequenced_waveform = to_waveform(cast(Loop, self[0])) + else: + sequenced_waveform = SequenceWaveform.from_sequence([to_waveform(cast(Loop, sub_program)) + for sub_program in self]) + if self.repetition_count > 1: + return RepetitionWaveform.from_repetition_count(sequenced_waveform, self.repetition_count) + else: + return sequenced_waveform + + def append_leaf(self, waveform: Waveform, + measurements: Optional[Sequence[MeasurementWindow]] = None, + repetition_count: int = 1): + self.append_child(waveform=waveform, measurements=measurements, repetition_count=repetition_count) + + def to_program(self) -> Optional['Loop']: + if self.waveform or self.children: + return self + def reverse_inplace(self): if self.is_leaf(): self._waveform = self._waveform.reversed() @@ -465,6 +535,37 @@ def reverse_inplace(self): for name, begin, length in self._measurements ] + def make_compatible_inplace(self, minimal_waveform_length: int, waveform_quantum: int, sample_rate: TimeType): + program = self + comp_level = _is_compatible(program, + min_len=minimal_waveform_length, + quantum=waveform_quantum, + sample_rate=sample_rate) + if comp_level == _CompatibilityLevel.incompatible_fraction: + raise ValueError( + 'The program duration in samples {} is not an integer'.format(program.duration * sample_rate)) + if comp_level == _CompatibilityLevel.incompatible_too_short: + raise ValueError('The program is too short to be a valid waveform. \n' + ' program duration in samples: {} \n' + ' minimal length: {}'.format(program.duration * sample_rate, minimal_waveform_length)) + if comp_level == _CompatibilityLevel.incompatible_quantum: + raise ValueError('The program duration in samples {} ' + 'is not a multiple of quantum {}'.format(program.duration * sample_rate, waveform_quantum)) + + elif comp_level == _CompatibilityLevel.action_required: + warnings.warn( + "qupulse will now concatenate waveforms to make the pulse/program compatible with the chosen AWG." + " This might take some time. If you need this pulse more often it makes sense to write it in a " + "way which is more AWG friendly.", MakeCompatibleWarning) + + _make_compatible(program, + min_len=minimal_waveform_length, + quantum=waveform_quantum, + sample_rate=sample_rate) + + else: + assert comp_level == _CompatibilityLevel.compatible + class ChannelSplit(Exception): def __init__(self, channel_sets): @@ -472,22 +573,7 @@ def __init__(self, channel_sets): def to_waveform(program: Loop) -> Waveform: - if program.is_leaf(): - if program.repetition_count == 1: - return program.waveform - else: - return RepetitionWaveform.from_repetition_count(program.waveform, program.repetition_count) - else: - if len(program) == 1: - sequenced_waveform = to_waveform(cast(Loop, program[0])) - else: - sequenced_waveform = SequenceWaveform.from_sequence( - [to_waveform(cast(Loop, sub_program)) - for sub_program in program]) - if program.repetition_count > 1: - return RepetitionWaveform.from_repetition_count(sequenced_waveform, program.repetition_count) - else: - return sequenced_waveform + return program.to_single_waveform() class _CompatibilityLevel(Enum): @@ -568,32 +654,7 @@ def _make_compatible(program: Loop, min_len: int, quantum: int, sample_rate: Tim def make_compatible(program: Loop, minimal_waveform_length: int, waveform_quantum: int, sample_rate: TimeType): """ check program for compatibility to AWG requirements, make it compatible if necessary and possible""" - comp_level = _is_compatible(program, - min_len=minimal_waveform_length, - quantum=waveform_quantum, - sample_rate=sample_rate) - if comp_level == _CompatibilityLevel.incompatible_fraction: - raise ValueError('The program duration in samples {} is not an integer'.format(program.duration * sample_rate)) - if comp_level == _CompatibilityLevel.incompatible_too_short: - raise ValueError('The program is too short to be a valid waveform. \n' - ' program duration in samples: {} \n' - ' minimal length: {}'.format(program.duration * sample_rate, minimal_waveform_length)) - if comp_level == _CompatibilityLevel.incompatible_quantum: - raise ValueError('The program duration in samples {} ' - 'is not a multiple of quantum {}'.format(program.duration * sample_rate, waveform_quantum)) - - elif comp_level == _CompatibilityLevel.action_required: - warnings.warn("qupulse will now concatenate waveforms to make the pulse/program compatible with the chosen AWG." - " This might take some time. If you need this pulse more often it makes sense to write it in a " - "way which is more AWG friendly.", MakeCompatibleWarning) - - _make_compatible(program, - min_len=minimal_waveform_length, - quantum=waveform_quantum, - sample_rate=sample_rate) - - else: - assert comp_level == _CompatibilityLevel.compatible + program.make_compatible_inplace(minimal_waveform_length, waveform_quantum, sample_rate) def roll_constant_waveforms(program: Loop, minimal_waveform_quanta: int, waveform_quantum: int, sample_rate: TimeType): diff --git a/qupulse/_program/transformation.py b/qupulse/_program/transformation.py index 66ccfc04..734b7194 100644 --- a/qupulse/_program/transformation.py +++ b/qupulse/_program/transformation.py @@ -6,7 +6,15 @@ from qupulse import ChannelID from qupulse.comparable import Comparable -from qupulse.utils.types import SingletonABCMeta +from qupulse.utils.types import SingletonABCMeta, use_rs_replacements + +try: + import qupulse_rs +except ImportError: + qupulse_rs = None + transformation_rs = None +else: + from qupulse_rs.replacements import transformation as transformation_rs class Transformation(Comparable): @@ -325,4 +333,13 @@ def chain_transformations(*transformations: Transformation) -> Transformation: elif len(parsed_transformations) == 1: return parsed_transformations[0] else: - return ChainedTransformation(*parsed_transformations) \ No newline at end of file + return ChainedTransformation(*parsed_transformations) + + + +if transformation_rs: + use_rs_replacements(globals(), transformation_rs, Transformation) + + py_chain_transformations = chain_transformations + rs_chain_transformations = transformation_rs.chain_transformations + chain_transformations = rs_chain_transformations diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index e9ce8f6f..3adb1846 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -25,10 +25,18 @@ from qupulse.expressions import ExpressionScalar from qupulse.pulses.interpolation import InterpolationStrategy from qupulse.utils import checked_int_cast, isclose -from qupulse.utils.types import TimeType, time_from_float, FrozenDict +from qupulse.utils.types import TimeType, time_from_float, FrozenDict, use_rs_replacements from qupulse._program.transformation import Transformation from qupulse.utils import pairwise +try: + import qupulse_rs +except ImportError: + qupulse_rs = None + waveforms_rs = None +else: + from qupulse_rs.replacements import waveforms as waveforms_rs + class ConstantFunctionPulseTemplateWarning(UserWarning): """ This warning indicates a constant waveform is constructed from a FunctionPulseTemplate """ pass @@ -850,8 +858,8 @@ def unsafe_sample(self, return output_array @property - def compare_key(self) -> Tuple[Any, int]: - return self._body.compare_key, self._repetition_count + def compare_key(self) -> Tuple[int, Any]: + return self._repetition_count, self._body def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> Waveform: return RepetitionWaveform.from_repetition_count( @@ -987,9 +995,14 @@ def unsafe_sample(self, return self.inner_waveform.unsafe_sample(channel, sample_times, output_array) def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]: - d = self._inner_waveform.constant_value_dict() - if d is not None: - return {ch: d[ch] for ch in self._channel_subset} + constant_values = {} + for ch in self.defined_channels: + value = self._inner_waveform.constant_value(ch) + if value is None: + return + else: + constant_values[ch] = value + return constant_values def constant_value(self, channel: ChannelID) -> Optional[float]: if channel not in self._channel_subset: @@ -1231,3 +1244,9 @@ def compare_key(self) -> Hashable: def reversed(self) -> 'Waveform': return self._inner + + + + +if waveforms_rs: + use_rs_replacements(globals(), waveforms_rs, Waveform) diff --git a/qupulse/expressions.py b/qupulse/expressions.py index b9b9b0d8..54926954 100644 --- a/qupulse/expressions.py +++ b/qupulse/expressions.py @@ -6,7 +6,7 @@ from typing import Any, Dict, Union, Sequence, Callable, TypeVar, Type, Mapping from numbers import Number import warnings -import functools +import inspect import array import itertools @@ -18,6 +18,14 @@ get_most_simple_representation, get_variables, evaluate_lamdified_exact_rational from qupulse.utils.types import TimeType +try: + import qupulse_rs +except ImportError: + qupulse_rs = None + RsExpressionScalar = None +else: + from qupulse_rs.replacements import ExpressionScalar as RsExpressionScalar + __all__ = ["Expression", "ExpressionVariableMissingException", "ExpressionScalar", "ExpressionVector", "ExpressionLike"] @@ -84,6 +92,15 @@ def __call__(cls: Type[_ExpressionType], *args, **kwargs) -> _ExpressionType: else: return type.__call__(cls, *args, **kwargs) + if RsExpressionScalar is not None: + def __subclasscheck__(cls, subclass): + return cls.__name__ == subclass.__name__ or super().__subclasscheck__(subclass) + + def __instancecheck__(cls, instance): + if cls is ExpressionScalar or cls is Expression: + return isinstance(instance, RsExpressionScalar) or super().__instancecheck__(instance) + super().__instancecheck__(instance) + class Expression(AnonymousSerializable, metaclass=_ExpressionMeta): """Base class for expressions.""" @@ -236,7 +253,7 @@ def __eq__(self, other): other = Expression.make(other) except (ValueError, TypeError): return NotImplemented - if isinstance(other, ExpressionScalar): + if type(other).__name__ == 'ExpressionScalar': return self._expression_shape in ((), (1,)) and self._expression_items[0] == other.sympified_expression else: return self._expression_shape == other._expression_shape and \ @@ -335,7 +352,12 @@ def variables(self) -> Sequence[str]: @classmethod def _sympify(cls, other: Union['ExpressionScalar', Number, sympy.Expr]) -> sympy.Expr: - return other._sympified_expression if isinstance(other, cls) else sympify(other) + return sympify(other) + + @classmethod + def _extract_sympified(cls, other: Union['ExpressionScalar', Number, sympy.Expr]) \ + -> Union['ExpressionScalar', Number, sympy.Expr]: + return getattr(other, '_sympified_expression', other) @classmethod def _extract_sympified(cls, other: Union['ExpressionScalar', Number, sympy.Expr]) \ @@ -465,7 +487,7 @@ def __str__(self) -> str: str(self.expression), self.variable) -class NonNumericEvaluation(Exception): +class NonNumericEvaluation(TypeError): """An exception that is raised if the result of evaluate_numeric is not a number. See also: @@ -492,3 +514,21 @@ def __str__(self) -> str: ExpressionLike = TypeVar('ExpressionLike', str, Number, sympy.Expr, ExpressionScalar) + + +if RsExpressionScalar: + PyExpressionScalar = ExpressionScalar + + class ExpressionScalar(PyExpressionScalar): + def __new__(cls, *args, **kwargs): + if qupulse_rs and cls.__name__ == 'ExpressionScalar': + try: + return RsExpressionScalar(*args, **kwargs) + except (ValueError, TypeError, RuntimeError): + pass + return PyExpressionScalar.__new__(cls) + +assert isinstance(ExpressionScalar('a'), ExpressionScalar) +assert isinstance(ExpressionScalar('a'), Expression) +if RsExpressionScalar: + assert issubclass(RsExpressionScalar, ExpressionScalar) diff --git a/qupulse/parameter_scope.py b/qupulse/parameter_scope.py index a59f94a0..8423f34d 100644 --- a/qupulse/parameter_scope.py +++ b/qupulse/parameter_scope.py @@ -7,7 +7,15 @@ import itertools from qupulse.expressions import Expression, ExpressionVariableMissingException -from qupulse.utils.types import FrozenMapping, FrozenDict +from qupulse.utils.types import FrozenMapping, FrozenDict, use_rs_replacements + +try: + import qupulse_rs +except ImportError: + qupulse_rs = None + parameter_scope_rs = None +else: + from qupulse_rs.replacements import parameter_scope as parameter_scope_rs class Scope(Mapping[str, Number]): @@ -319,3 +327,7 @@ def __str__(self) -> str: class NonVolatileChange(RuntimeWarning): """Raised if a non volatile parameter is updated""" + + +if parameter_scope_rs: + use_rs_replacements(globals(), parameter_scope_rs, Scope) diff --git a/qupulse/pulses/arithmetic_pulse_template.py b/qupulse/pulses/arithmetic_pulse_template.py index 2e997472..1d765f47 100644 --- a/qupulse/pulses/arithmetic_pulse_template.py +++ b/qupulse/pulses/arithmetic_pulse_template.py @@ -1,4 +1,3 @@ - from typing import Any, Dict, List, Set, Optional, Union, Mapping, FrozenSet, cast, Callable from numbers import Real import warnings diff --git a/qupulse/pulses/constant_pulse_template.py b/qupulse/pulses/constant_pulse_template.py index 24ae1684..43c22030 100644 --- a/qupulse/pulses/constant_pulse_template.py +++ b/qupulse/pulses/constant_pulse_template.py @@ -10,6 +10,8 @@ import numbers from typing import Any, Dict, List, Optional, Union, Mapping, AbstractSet +from qupulse.utils import cached_property +from qupulse._program import ProgramBuilder from qupulse._program.waveforms import ConstantWaveform from qupulse.utils.types import TimeType, ChannelID from qupulse.utils import cached_property diff --git a/qupulse/pulses/loop_pulse_template.py b/qupulse/pulses/loop_pulse_template.py index f47520a7..b0ea8af4 100644 --- a/qupulse/pulses/loop_pulse_template.py +++ b/qupulse/pulses/loop_pulse_template.py @@ -13,7 +13,7 @@ from qupulse.parameter_scope import Scope, MappedScope, DictScope from qupulse.utils.types import FrozenDict, FrozenMapping -from qupulse._program._loop import Loop +from qupulse._program import ProgramBuilder from qupulse.expressions import ExpressionScalar, ExpressionVariableMissingException, Expression from qupulse.utils import checked_int_cast, cached_property @@ -149,26 +149,17 @@ def _internal_create_program(self, *, channel_mapping: Dict[ChannelID, Optional[ChannelID]], global_transformation: Optional['Transformation'], to_single_waveform: Set[Union[str, 'PulseTemplate']], - parent_loop: Loop) -> None: + parent_loop: ProgramBuilder) -> None: self.validate_scope(scope=scope) - try: - duration = self.duration.evaluate_in_scope(scope) - except ExpressionVariableMissingException as err: - raise ParameterNotProvidedException(err.variable) from err - - if duration > 0: - measurements = self.get_measurement_windows(scope, measurement_mapping) - if measurements: - parent_loop.add_measurements(measurements) - + with parent_loop.potential_child(measurements=self.get_measurement_windows(scope, measurement_mapping)) as for_loop: for local_scope in self._body_scope_generator(scope, forward=True): self.body._create_program(scope=local_scope, measurement_mapping=measurement_mapping, channel_mapping=channel_mapping, global_transformation=global_transformation, to_single_waveform=to_single_waveform, - parent_loop=parent_loop) + parent_loop=for_loop) def build_waveform(self, parameter_scope: Scope) -> ForLoopWaveform: return ForLoopWaveform([self.body.build_waveform(local_scope) diff --git a/qupulse/pulses/mapping_pulse_template.py b/qupulse/pulses/mapping_pulse_template.py index f235b28e..5a1a4316 100644 --- a/qupulse/pulses/mapping_pulse_template.py +++ b/qupulse/pulses/mapping_pulse_template.py @@ -9,7 +9,7 @@ from qupulse.pulses.pulse_template import PulseTemplate, MappingTuple from qupulse.pulses.parameters import Parameter, MappedParameter, ParameterNotProvidedException, ParameterConstrainer from qupulse._program.waveforms import Waveform -from qupulse._program._loop import Loop +from qupulse._program import ProgramBuilder from qupulse.serialization import Serializer, PulseRegistryType __all__ = [ @@ -202,7 +202,7 @@ def defined_channels(self) -> Set[ChannelID]: @property def duration(self) -> Expression: return self.__template.duration.evaluate_symbolic( - {parameter_name: expression.underlying_expression + {parameter_name: expression for parameter_name, expression in self.__parameter_mapping.items()} ) @@ -317,7 +317,7 @@ def _internal_create_program(self, *, channel_mapping: Dict[ChannelID, Optional[ChannelID]], global_transformation: Optional['Transformation'], to_single_waveform: Set[Union[str, 'PulseTemplate']], - parent_loop: Loop) -> None: + parent_loop: ProgramBuilder) -> None: self.validate_scope(scope) # parameters are validated in map_parameters() call, no need to do it here again explicitly diff --git a/qupulse/pulses/plotting.py b/qupulse/pulses/plotting.py index 8514a635..a021b417 100644 --- a/qupulse/pulses/plotting.py +++ b/qupulse/pulses/plotting.py @@ -14,7 +14,7 @@ import operator import itertools -from qupulse._program import waveforms +from qupulse._program import waveforms, Program from qupulse.utils.types import ChannelID, MeasurementWindow, has_type_interface from qupulse.pulses.pulse_template import PulseTemplate from qupulse.pulses.parameters import Parameter @@ -52,6 +52,12 @@ def render(program: Union[Loop], """ if has_type_interface(program, Loop): waveform, measurements = _render_loop(program, render_measurements=render_measurements) + elif isinstance(program, Program): + waveform = program.to_single_waveform() + measurements = program.get_measurement_windows() + measurements = [(name, begin, length) + for name, (begins, lengths) in measurements.items() + for begin, length in zip(begins, lengths)] else: raise ValueError('Cannot render an object of type %r' % type(program), program) diff --git a/qupulse/pulses/pulse_template.py b/qupulse/pulses/pulse_template.py index 1ed0786e..740dd63f 100644 --- a/qupulse/pulses/pulse_template.py +++ b/qupulse/pulses/pulse_template.py @@ -18,6 +18,7 @@ from qupulse.utils import forced_hash from qupulse.serialization import Serializable from qupulse.expressions import ExpressionScalar, Expression, ExpressionLike +from qupulse._program import ProgramBuilder, default_program_builder from qupulse._program._loop import Loop, to_waveform from qupulse._program.transformation import Transformation, IdentityTransformation, ChainedTransformation, chain_transformations @@ -165,7 +166,7 @@ def create_program(self, *, scope = DictScope(values=FrozenDict(parameters), volatile=volatile) - root_loop = Loop() + root_loop = default_program_builder() # call subclass specific implementation self._create_program(scope=scope, @@ -175,9 +176,7 @@ def create_program(self, *, to_single_waveform=to_single_waveform, parent_loop=root_loop) - if root_loop.waveform is None and len(root_loop.children) == 0: - return None # return None if no program - return root_loop + return root_loop.to_program() @abstractmethod def _internal_create_program(self, *, @@ -186,7 +185,7 @@ def _internal_create_program(self, *, channel_mapping: Dict[ChannelID, Optional[ChannelID]], global_transformation: Optional[Transformation], to_single_waveform: Set[Union[str, 'PulseTemplate']], - parent_loop: Loop) -> None: + parent_loop: ProgramBuilder) -> None: """The subclass specific implementation of create_program(). Receives a Loop instance parent_loop to which it should append measurements and its own Loops as children. @@ -207,36 +206,37 @@ def _create_program(self, *, channel_mapping: Dict[ChannelID, Optional[ChannelID]], global_transformation: Optional[Transformation], to_single_waveform: Set[Union[str, 'PulseTemplate']], - parent_loop: Loop): + parent_loop: ProgramBuilder): """Generic part of create program. This method handles to_single_waveform and the configuration of the transformer.""" if self.identifier in to_single_waveform or self in to_single_waveform: - root = Loop() - if not scope.get_volatile_parameters().keys().isdisjoint(self.parameter_names): raise NotImplementedError('A pulse template that has volatile parameters cannot be transformed into a ' 'single waveform yet.') + builder = default_program_builder() self._internal_create_program(scope=scope, measurement_mapping=measurement_mapping, channel_mapping=channel_mapping, global_transformation=None, to_single_waveform=to_single_waveform, - parent_loop=root) + parent_loop=builder) - waveform = to_waveform(root) + program = builder.to_program() + if program is not None: + # we use the free function here for better testability + waveform = to_waveform(program) - if global_transformation: - waveform = TransformingWaveform.from_transformation(waveform, global_transformation) + if global_transformation: + waveform = TransformingWaveform.from_transformation(waveform, global_transformation) - # convert the nicely formatted measurement windows back into the old format again :( - measurements = root.get_measurement_windows() - measurement_window_list = [] - for measurement_name, (begins, lengths) in measurements.items(): - measurement_window_list.extend(zip(itertools.repeat(measurement_name), begins, lengths)) + # convert the nicely formatted measurement windows back into the old format again :( + measurements = program.get_measurement_windows() + measurement_window_list = [] + for measurement_name, (begins, lengths) in measurements.items(): + measurement_window_list.extend(zip(itertools.repeat(measurement_name), begins, lengths)) - parent_loop.add_measurements(measurement_window_list) - parent_loop.append_child(waveform=waveform) + parent_loop.append_leaf(waveform=waveform, measurements=measurement_window_list) else: self._internal_create_program(scope=scope, @@ -329,7 +329,7 @@ def _internal_create_program(self, *, channel_mapping: Dict[ChannelID, Optional[ChannelID]], global_transformation: Optional[Transformation], to_single_waveform: Set[Union[str, 'PulseTemplate']], - parent_loop: Loop) -> None: + parent_loop: ProgramBuilder) -> None: """Parameter constraints are validated in build_waveform because build_waveform is guaranteed to be called during sequencing""" ### current behavior (same as previously): only adds EXEC Loop and measurements if a waveform exists. @@ -345,8 +345,7 @@ def _internal_create_program(self, *, if global_transformation: waveform = TransformingWaveform.from_transformation(waveform, global_transformation) - parent_loop.add_measurements(measurements=measurements) - parent_loop.append_child(waveform=waveform) + parent_loop.append_leaf(waveform=waveform, measurements=measurements) @abstractmethod def build_waveform(self, diff --git a/qupulse/pulses/repetition_pulse_template.py b/qupulse/pulses/repetition_pulse_template.py index cf9f57a2..9e8ad9ec 100644 --- a/qupulse/pulses/repetition_pulse_template.py +++ b/qupulse/pulses/repetition_pulse_template.py @@ -8,7 +8,8 @@ import numpy as np from qupulse.serialization import Serializer, PulseRegistryType -from qupulse._program._loop import Loop, VolatileRepetitionCount +from qupulse._program import ProgramBuilder +from qupulse._program.volatile import VolatileRepetitionCount from qupulse.parameter_scope import Scope from qupulse.utils.types import ChannelID @@ -105,7 +106,7 @@ def _internal_create_program(self, *, channel_mapping: Dict[ChannelID, Optional[ChannelID]], global_transformation: Optional['Transformation'], to_single_waveform: Set[Union[str, 'PulseTemplate']], - parent_loop: Loop) -> None: + parent_loop: ProgramBuilder) -> None: self.validate_scope(scope) repetition_count = max(0, self.get_repetition_count_value(scope)) @@ -119,19 +120,15 @@ def _internal_create_program(self, *, else: repetition_definition = repetition_count - repj_loop = Loop(repetition_count=repetition_definition) - self.body._create_program(scope=scope, - measurement_mapping=measurement_mapping, - channel_mapping=channel_mapping, - global_transformation=global_transformation, - to_single_waveform=to_single_waveform, - parent_loop=repj_loop) - if repj_loop.waveform is not None or len(repj_loop.children) > 0: - measurements = self.get_measurement_windows(scope, measurement_mapping) - if measurements: - parent_loop.add_measurements(measurements) - - parent_loop.append_child(loop=repj_loop) + measurements = self.get_measurement_windows(scope, measurement_mapping) or None + + with parent_loop.potential_child(measurements, repetition_count=repetition_definition) as repj_loop: + self.body._create_program(scope=scope, + measurement_mapping=measurement_mapping, + channel_mapping=channel_mapping, + global_transformation=global_transformation, + to_single_waveform=to_single_waveform, + parent_loop=repj_loop) def get_serialization_data(self, serializer: Optional[Serializer]=None) -> Dict[str, Any]: data = super().get_serialization_data(serializer) diff --git a/qupulse/pulses/sequence_pulse_template.py b/qupulse/pulses/sequence_pulse_template.py index 80c06e4e..fca353d8 100644 --- a/qupulse/pulses/sequence_pulse_template.py +++ b/qupulse/pulses/sequence_pulse_template.py @@ -8,7 +8,7 @@ import warnings from qupulse.serialization import Serializer, PulseRegistryType -from qupulse._program._loop import Loop +from qupulse._program import ProgramBuilder from qupulse.parameter_scope import Scope from qupulse.utils import cached_property from qupulse.utils.types import MeasurementWindow, ChannelID, TimeType @@ -133,21 +133,17 @@ def _internal_create_program(self, *, channel_mapping: Dict[ChannelID, Optional[ChannelID]], global_transformation: Optional['Transformation'], to_single_waveform: Set[Union[str, 'PulseTemplate']], - parent_loop: Loop) -> None: + parent_loop: ProgramBuilder) -> None: self.validate_scope(scope) - if self.duration.evaluate_in_scope(scope) > 0: - measurements = self.get_measurement_windows(scope, measurement_mapping) - if measurements: - parent_loop.add_measurements(measurements) - + with parent_loop.potential_child(measurements=self.get_measurement_windows(scope, measurement_mapping)) as seq_loop: for subtemplate in self.subtemplates: subtemplate._create_program(scope=scope, measurement_mapping=measurement_mapping, channel_mapping=channel_mapping, global_transformation=global_transformation, to_single_waveform=to_single_waveform, - parent_loop=parent_loop) + parent_loop=seq_loop) def get_serialization_data(self, serializer: Optional[Serializer]=None) -> Dict[str, Any]: data = super().get_serialization_data(serializer) diff --git a/qupulse/serialization.py b/qupulse/serialization.py index 3825057e..b1eabaee 100644 --- a/qupulse/serialization.py +++ b/qupulse/serialization.py @@ -30,6 +30,7 @@ import gc import importlib import warnings +from typing import Protocol, runtime_checkable from contextlib import contextmanager from qupulse.utils.types import DocStringABCMeta, FrozenDict @@ -1064,7 +1065,7 @@ def default(self, o: Any) -> Any: else: return o.get_serialization_data() - elif isinstance(o, AnonymousSerializable): + elif hasattr(o, 'get_serialization_data'): return o.get_serialization_data() elif type(o) is set: @@ -1091,7 +1092,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def default(self, o: Any) -> Any: - if isinstance(o, AnonymousSerializable): + if hasattr(o, 'get_serialization_data'): return o.get_serialization_data() elif type(o) is set: return list(o) diff --git a/qupulse/utils/tree.py b/qupulse/utils/tree.py index 2585a5f5..32f460e3 100644 --- a/qupulse/utils/tree.py +++ b/qupulse/utils/tree.py @@ -152,6 +152,9 @@ def locate(self: _NodeType, location: Tuple[int, ...]) -> _NodeType: else: return self + def _children_repr(self): + return repr(self.__children) + def _reverse_children(self): """Reverse children in-place""" self.__children.reverse() diff --git a/qupulse/utils/types.py b/qupulse/utils/types.py index 7b467b0d..f59b497b 100644 --- a/qupulse/utils/types.py +++ b/qupulse/utils/types.py @@ -18,6 +18,12 @@ "will be removed in a future release.", category=DeprecationWarning) frozendict = None +try: + from qupulse_rs.qupulse_rs import TimeType as RsTimeType + numbers.Rational.register(RsTimeType) +except ImportError: + RsTimeType = None + import qupulse.utils.numeric as qupulse_numeric __all__ = ["MeasurementWindow", "ChannelID", "HashableNumpyArray", "TimeType", "time_from_float", "DocStringABCMeta", @@ -234,7 +240,7 @@ def __gt__(self, other): def __eq__(self, other): if type(other) == type(self): - return self._value.__eq__(other._value) + return self._value == other._value else: return self._value == other @@ -310,16 +316,22 @@ def __float__(self): return int(self._value.numerator) / int(self._value.denominator) +PyTimeType = TimeType + +if RsTimeType: + TimeType = RsTimeType + + # this asserts isinstance(TimeType, Rational) is True numbers.Rational.register(TimeType) _converter = { - float: TimeType.from_float, - TimeType._InternalType: TimeType, - fractions.Fraction: TimeType, - sympy.Rational: lambda q: TimeType.from_fraction(q.p, q.q), - TimeType: lambda x: x + float: PyTimeType.from_float, + PyTimeType._InternalType: PyTimeType, + fractions.Fraction: PyTimeType, + sympy.Rational: lambda q: PyTimeType.from_fraction(q.p, q.q), + PyTimeType: lambda x: x } @@ -558,3 +570,20 @@ def __eq__(self, other): return NotImplemented +def use_rs_replacements(glbls, rs_replacement, base_class: type): + name_suffix = base_class.__name__ + for name, rs_obj in vars(rs_replacement).items(): + if not name.endswith(name_suffix): + continue + + py_name = f'Py{name}' + rs_name = f'Rs{name}' + glbls[name] = rs_obj + try: + py_obj = glbls[name] + except KeyError: + pass + else: + glbls.setdefault(py_name, py_obj) + glbls.setdefault(rs_name, rs_obj) + base_class.register(rs_obj) diff --git a/setup.cfg b/setup.cfg index d85ed227..c430620b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,6 +33,7 @@ test_suite = tests tests = pytest pytest_benchmark + matplotlib docs = sphinx>=4 nbsphinx diff --git a/tests/_program/loop_tests.py b/tests/_program/loop_tests.py index 3bd811cd..d11df00c 100644 --- a/tests/_program/loop_tests.py +++ b/tests/_program/loop_tests.py @@ -143,6 +143,10 @@ def test_get_measurement_windows(self): self.assertEqual({}, prog.get_measurement_windows()) def test_repr(self): + tree = self.get_test_loop() + self.assertEqual(tree, eval(repr(tree))) + + def test_str(self): wf_gen = WaveformGenerator(num_channels=1) wfs = [wf_gen() for _ in range(11)] @@ -154,10 +158,10 @@ def test_repr(self): loop.waveform = wfs.pop(0) self.assertEqual(len(wfs), 0) - self.assertEqual(repr(tree), expected) + self.assertEqual(str(tree), expected) with mock.patch.object(Loop, 'MAX_REPR_SIZE', 1): - self.assertEqual(repr(tree), '...') + self.assertEqual(str(tree), '...') def test_is_leaf(self): root_loop = self.get_test_loop(waveform_generator=WaveformGenerator(1)) @@ -221,7 +225,7 @@ def test_flatten_and_balance(self): ->LOOP 9 times: ->EXEC {J} 10 times ->EXEC {K} 11 times""".format(**wf_reprs) - self.assertEqual(repr(before), before_repr) + self.assertEqual(str(before), before_repr) expected_after_repr = """\ LOOP 1 times: @@ -263,7 +267,7 @@ def test_flatten_and_balance(self): ->EXEC {J} 10 times ->EXEC {K} 11 times""".format(**wf_reprs) - self.assertEqual(expected_after_repr, repr(after)) + self.assertEqual(expected_after_repr, str(after)) def test_flatten_and_balance_comparison_based(self): wfs = [DummyWaveform(duration=i) for i in range(2)] @@ -425,29 +429,22 @@ def test_make_compatible_partial_unroll(self): program = Loop(children=[Loop(waveform=wf1, repetition_count=2), Loop(waveform=wf2)]) + expected_program = Loop(children=[ + Loop(waveform=RepetitionWaveform(wf1, 2)), + Loop(waveform=wf2) + ]) _make_compatible(program, min_len=1, quantum=1, sample_rate=TimeType.from_float(1.)) - - self.assertIsNone(program.waveform) - self.assertEqual(len(program), 2) - self.assertIsInstance(program[0].waveform, RepetitionWaveform) - self.assertIs(program[0].waveform._body, wf1) - self.assertEqual(program[0].waveform._repetition_count, 2) - self.assertIs(program[1].waveform, wf2) + self.assertEqual(expected_program, program) program = Loop(children=[Loop(waveform=wf1, repetition_count=2), - Loop(waveform=wf2)], repetition_count=2) + Loop(waveform=wf2)], repetition_count=3) + expected_program = Loop(waveform=SequenceWaveform([ + RepetitionWaveform(wf1, 2), + wf2 + ]), repetition_count=3) _make_compatible(program, min_len=5, quantum=1, sample_rate=TimeType.from_float(1.)) - - self.assertIsInstance(program.waveform, SequenceWaveform) - self.assertEqual(list(program.children), []) - self.assertEqual(program.repetition_count, 2) - - self.assertEqual(len(program.waveform._sequenced_waveforms), 2) - self.assertIsInstance(program.waveform._sequenced_waveforms[0], RepetitionWaveform) - self.assertIs(program.waveform._sequenced_waveforms[0]._body, wf1) - self.assertEqual(program.waveform._sequenced_waveforms[0]._repetition_count, 2) - self.assertIs(program.waveform._sequenced_waveforms[1], wf2) + self.assertEqual(expected_program, program) def test_make_compatible_complete_unroll(self): wf1 = DummyWaveform(duration=1.5) @@ -456,21 +453,18 @@ def test_make_compatible_complete_unroll(self): program = Loop(children=[Loop(waveform=wf1, repetition_count=2), Loop(waveform=wf2, repetition_count=1)], repetition_count=2) - _make_compatible(program, min_len=5, quantum=10, sample_rate=TimeType.from_float(1.)) - - self.assertIsInstance(program.waveform, RepetitionWaveform) - self.assertEqual(list(program.children), []) - self.assertEqual(program.repetition_count, 1) + expected_program = Loop(repetition_count=1, + waveform=RepetitionWaveform( + body=SequenceWaveform([ + RepetitionWaveform(wf1, 2), + wf2 + ]), + repetition_count=2 + ) + ) - self.assertIsInstance(program.waveform, RepetitionWaveform) - - self.assertIsInstance(program.waveform._body, SequenceWaveform) - body_wf = program.waveform._body - self.assertEqual(len(body_wf._sequenced_waveforms), 2) - self.assertIsInstance(body_wf._sequenced_waveforms[0], RepetitionWaveform) - self.assertIs(body_wf._sequenced_waveforms[0]._body, wf1) - self.assertEqual(body_wf._sequenced_waveforms[0]._repetition_count, 2) - self.assertIs(body_wf._sequenced_waveforms[1], wf2) + _make_compatible(program, min_len=5, quantum=10, sample_rate=TimeType.from_float(1.)) + self.assertEqual(expected_program, program) def test_make_compatible(self): program = Loop() diff --git a/tests/_program/seqc_tests.py b/tests/_program/seqc_tests.py index 6b2cfb0d..217ecff9 100644 --- a/tests/_program/seqc_tests.py +++ b/tests/_program/seqc_tests.py @@ -74,7 +74,7 @@ def make_binary_waveform(waveform): return (BinaryWaveform(data),) else: chs = sorted(waveform.defined_channels) - t = np.arange(0., float(waveform.duration), 1.) + t = np.arange(0., float(waveform.duration), 1., dtype=float) sampled = [None if ch is None else waveform.get_sampled(ch, t) for _, ch in zip_longest(range(6), take(6, chs), fillvalue=None)] diff --git a/tests/_program/transformation_tests.py b/tests/_program/transformation_tests.py index e75e17dc..ff4b6c03 100644 --- a/tests/_program/transformation_tests.py +++ b/tests/_program/transformation_tests.py @@ -2,10 +2,11 @@ from unittest import mock import numpy as np +import numpy.testing from qupulse._program.transformation import LinearTransformation, Transformation, IdentityTransformation,\ ChainedTransformation, ParallelConstantChannelTransformation, chain_transformations, OffsetTransformation,\ - ScalingTransformation + ScalingTransformation, transformation_rs class TransformationStub(Transformation): @@ -59,10 +60,11 @@ def test_compare_key_and_init(self): matrix_2 = np.array([[1, 1, 1], [1, 0, -1]]) trafo_2 = LinearTransformation(matrix_2, in_chs_2, out_chs_2) - self.assertEqual(trafo.compare_key, trafo_2.compare_key) + if transformation_rs is None: + self.assertEqual(trafo.compare_key, trafo_2.compare_key) + self.assertEqual(trafo.compare_key, (in_chs, out_chs, matrix.tobytes())) self.assertEqual(trafo, trafo_2) self.assertEqual(hash(trafo), hash(trafo_2)) - self.assertEqual(trafo.compare_key, (in_chs, out_chs, matrix.tobytes())) def test_from_pandas(self): try: @@ -93,14 +95,14 @@ def test_get_output_channels(self): def test_get_input_channels(self): in_chs = ('a', 'b', 'c') out_chs = ('transformed_a', 'transformed_b') - matrix = np.array([[1, -1, 0], [1, 1, 1]]) + matrix = np.array([[1., -1, 0], [1, 1, 1]]) trafo = LinearTransformation(matrix, in_chs, out_chs) self.assertEqual(trafo.get_input_channels({'transformed_a'}), {'a', 'b', 'c'}) self.assertEqual(trafo.get_input_channels({'transformed_a', 'd'}), {'a', 'b', 'c', 'd'}) self.assertEqual(trafo.get_input_channels({'d'}), {'d'}) with self.assertRaisesRegex(KeyError, 'Is input channel'): - self.assertEqual(trafo.get_input_channels({'transformed_a', 'a'}), {'a', 'b', 'c', 'd'}) + trafo.get_input_channels({'transformed_a', 'a'}) in_chs = ('a', 'b', 'c') out_chs = ('a', 'b', 'c') @@ -108,7 +110,7 @@ def test_get_input_channels(self): trafo = LinearTransformation(matrix, in_chs, out_chs) in_set = {'transformed_a'} - self.assertIs(trafo.get_input_channels(in_set), in_set) + self.assertEqual(trafo.get_input_channels(in_set), in_set) self.assertEqual(trafo.get_input_channels({'transformed_a', 'a'}), {'transformed_a', 'a', 'b', 'c'}) def test_call(self): @@ -143,7 +145,7 @@ def test_call(self): data_in = {'ignored': np.arange(116., 120.)} transformed = trafo(np.full(4, np.NaN), data_in) np.testing.assert_equal(transformed, data_in) - self.assertIs(data_in['ignored'], transformed['ignored']) + np.testing.assert_equal(data_in['ignored'], transformed['ignored']) def test_repr(self): in_chs = ('a', 'b', 'c') @@ -170,23 +172,25 @@ def test_constant_propagation(self): class IdentityTransformationTests(unittest.TestCase): def test_compare_key(self): - self.assertIsNone(IdentityTransformation().compare_key) + self.assertEqual(IdentityTransformation(), IdentityTransformation()) + self.assertEqual({IdentityTransformation()}, {IdentityTransformation(), IdentityTransformation()}) + @unittest.skipIf(transformation_rs is not None, "Not implemented yet for rust") def test_singleton(self): self.assertIs(IdentityTransformation(), IdentityTransformation()) def test_call(self): time = np.arange(12) data = dict(zip('abc',(np.arange(12.) + 1).reshape((3, 4)))) - self.assertIs(IdentityTransformation()(time, data), data) + self.assertEqual(IdentityTransformation()(time, data), data) def test_output_channels(self): chans = {'a', 'b'} - self.assertIs(IdentityTransformation().get_output_channels(chans), chans) + self.assertEqual(IdentityTransformation().get_output_channels(chans), chans) def test_input_channels(self): chans = {'a', 'b'} - self.assertIs(IdentityTransformation().get_input_channels(chans), chans) + self.assertEqual(IdentityTransformation().get_input_channels(chans), chans) def test_chain(self): trafo = TransformationStub() @@ -209,7 +213,13 @@ def test_init_and_properties(self): chained = ChainedTransformation(*trafos) self.assertEqual(chained.transformations, trafos) - self.assertIs(chained.transformations, chained.compare_key) + + def test_equality(self): + trafos1 = TransformationStub(), TransformationStub(), TransformationStub() + trafos2 = trafos1[0], trafos1[1], TransformationStub() + self.assertEqual(ChainedTransformation(*trafos1), ChainedTransformation(*trafos1)) + self.assertNotEqual(ChainedTransformation(*trafos1), ChainedTransformation(*trafos2)) + self.assertEqual({ChainedTransformation(*trafos1)}, {ChainedTransformation(*trafos1), ChainedTransformation(*trafos1)}) def test_get_output_channels(self): trafos = TransformationStub(), TransformationStub(), TransformationStub() @@ -221,7 +231,7 @@ def test_get_output_channels(self): mock.patch.object(trafos[2], 'get_output_channels', return_value=chans[2]) as get_output_channels_2: outs = chained.get_output_channels({0}) - self.assertIs(outs, chans[2]) + self.assertEqual(outs, chans[2]) get_output_channels_0.assert_called_once_with({0}) get_output_channels_1.assert_called_once_with({1}) get_output_channels_2.assert_called_once_with({2}) @@ -237,7 +247,7 @@ def test_get_input_channels(self): mock.patch.object(trafos[0], 'get_input_channels', return_value=chans[2]) as get_input_channels_2: outs = chained.get_input_channels({0}) - self.assertIs(outs, chans[2]) + self.assertEqual(outs, chans[2]) get_input_channels_0.assert_called_once_with({0}) get_input_channels_1.assert_called_once_with({1}) get_input_channels_2.assert_called_once_with({2}) @@ -253,18 +263,19 @@ def test_call(self): data_0 = dict(zip('abc', data + 42)) data_1 = dict(zip('abc', data + 2*42)) data_2 = dict(zip('abc', data + 3*42)) - with mock.patch('tests._program.transformation_tests.TransformationStub.__call__', - side_effect=[data_0, data_1, data_2]) as call: + with mock.patch.object(TransformationStub, '__call__', + side_effect=[data_0, data_1, data_2]) as call: outs = chained(time, data_in) - - self.assertIs(outs, data_2) self.assertEqual(call.call_count, 3) + numpy.testing.assert_equal(outs, data_2) + for ((time_arg, data_arg), kwargs), expected_data in zip(call.call_args_list, [data_in, data_0, data_1]): self.assertEqual(kwargs, {}) - self.assertIs(time, time_arg) - self.assertIs(expected_data, data_arg) + numpy.testing.assert_equal(time, time_arg) + numpy.testing.assert_equal(expected_data, data_arg) + @unittest.skipIf(transformation_rs is not None, "Not implemented for rust extension") def test_chain(self): trafos = TransformationStub(), TransformationStub() trafo = TransformationStub() @@ -279,6 +290,9 @@ def test_repr(self): trafo = ChainedTransformation(ScalingTransformation({'a': 1.1}), OffsetTransformation({'b': 6.6})) self.assertEqual(trafo, eval(repr(trafo))) + stub = TransformationStub() + self.assertEqual(repr(ChainedTransformation(stub, stub)), repr(ChainedTransformation(stub, stub))) + def test_constant_propagation(self): trafo = ChainedTransformation(ScalingTransformation({'a': 1.1}), OffsetTransformation({'b': 6.6})) self.assertTrue(trafo.is_constant_invariant()) @@ -292,10 +306,9 @@ def test_init(self): trafo = ParallelConstantChannelTransformation(channels) - self.assertEqual(trafo._channels, channels) - self.assertTrue(all(isinstance(v, float) for v in trafo._channels.values())) - - self.assertEqual(trafo.compare_key, (('X', 2.), ('Y', 4.4))) + if transformation_rs is None: + self.assertEqual(trafo._channels, channels) + self.assertTrue(all(isinstance(v, float) for v in trafo._channels.values())) self.assertEqual(trafo.get_input_channels(set()), set()) self.assertEqual(trafo.get_input_channels({'X'}), set()) @@ -306,6 +319,21 @@ def test_init(self): self.assertEqual(trafo.get_output_channels({'X'}), {'X', 'Y'}) self.assertEqual(trafo.get_output_channels({'X', 'Z'}), {'X', 'Y', 'Z'}) + def test_equality(self): + constants_1 = {'a': 1.3, 'b': 3.0} + constants_2 = {'a': 1.3, 'b': 3.1} + + + self.assertEqual(ParallelConstantChannelTransformation(constants_1), ParallelConstantChannelTransformation(constants_1)) + self.assertEqual(ParallelConstantChannelTransformation(constants_1), ParallelConstantChannelTransformation(constants_1.copy())) + self.assertNotEqual(ParallelConstantChannelTransformation(constants_1), ParallelConstantChannelTransformation(constants_2)) + + self.assertEqual({ParallelConstantChannelTransformation(constants_1)}, + {ParallelConstantChannelTransformation(constants_1), ParallelConstantChannelTransformation(constants_1)}) + self.assertEqual({ParallelConstantChannelTransformation(constants_1), ParallelConstantChannelTransformation(constants_2)}, + {ParallelConstantChannelTransformation(constants_1), + ParallelConstantChannelTransformation(constants_1), ParallelConstantChannelTransformation(constants_2)}) + def test_trafo(self): channels = {'X': 2, 'Y': 4.4} trafo = ParallelConstantChannelTransformation(channels) @@ -347,16 +375,16 @@ def test_constant_propagation(self): class TestChaining(unittest.TestCase): def test_identity_result(self): - self.assertIs(chain_transformations(), IdentityTransformation()) + self.assertEqual(chain_transformations(), IdentityTransformation()) - self.assertIs(chain_transformations(IdentityTransformation(), IdentityTransformation()), + self.assertEqual(chain_transformations(IdentityTransformation(), IdentityTransformation()), IdentityTransformation()) def test_single_transformation(self): trafo = TransformationStub() - self.assertIs(chain_transformations(trafo), trafo) - self.assertIs(chain_transformations(trafo, IdentityTransformation()), trafo) + self.assertEqual(chain_transformations(trafo), trafo) + self.assertEqual(chain_transformations(trafo, IdentityTransformation()), trafo) def test_denesting(self): trafo = TransformationStub() @@ -381,6 +409,7 @@ class TestOffsetTransformation(unittest.TestCase): def setUp(self) -> None: self.offsets = {'A': 1., 'B': 1.2} + @unittest.skipIf(transformation_rs is not None, "Not relevant for rust extension") def test_init(self): trafo = OffsetTransformation(self.offsets) # test copy @@ -391,13 +420,17 @@ def test_init(self): def test_get_input_channels(self): trafo = OffsetTransformation(self.offsets) channels = {'A', 'C'} - self.assertIs(channels, trafo.get_input_channels(channels)) - self.assertIs(channels, trafo.get_output_channels(channels)) + self.assertEqual(channels, trafo.get_input_channels(channels)) + self.assertEqual(channels, trafo.get_output_channels(channels)) - def test_compare_key(self): + def test_comparison(self): trafo = OffsetTransformation(self.offsets) - _ = hash(trafo) - self.assertEqual(frozenset([('A', 1.), ('B', 1.2)]), trafo.compare_key) + + self.assertEqual(OffsetTransformation(self.offsets.copy()), OffsetTransformation(self.offsets.copy())) + self.assertEqual({OffsetTransformation(self.offsets.copy())}, + {OffsetTransformation(self.offsets.copy()), OffsetTransformation(self.offsets.copy())}) + self.assertEqual({OffsetTransformation(self.offsets.copy()), OffsetTransformation({**self.offsets, 'C': 9})}, + {OffsetTransformation(self.offsets.copy()), OffsetTransformation({**self.offsets, 'C': 9})}) def test_trafo(self): trafo = OffsetTransformation(self.offsets) @@ -429,6 +462,7 @@ class TestScalingTransformation(unittest.TestCase): def setUp(self) -> None: self.scales = {'A': 1.5, 'B': 1.2} + @unittest.skipIf(transformation_rs is not None, "Only relevant for pure python") def test_init(self): trafo = ScalingTransformation(self.scales) # test copy @@ -438,13 +472,16 @@ def test_init(self): def test_get_input_channels(self): trafo = ScalingTransformation(self.scales) channels = {'A', 'C'} - self.assertIs(channels, trafo.get_input_channels(channels)) - self.assertIs(channels, trafo.get_output_channels(channels)) + self.assertEqual(channels, trafo.get_input_channels(channels)) + self.assertEqual(channels, trafo.get_output_channels(channels)) def test_compare_key(self): - trafo = OffsetTransformation(self.scales) - _ = hash(trafo) - self.assertEqual(frozenset([('A', 1.5), ('B', 1.2)]), trafo.compare_key) + other_scales = {**self.scales, 'H': 3.} + self.assertEqual(ScalingTransformation(self.scales), ScalingTransformation(self.scales)) + self.assertNotEqual(ScalingTransformation(self.scales), ScalingTransformation(other_scales)) + + self.assertEqual({ScalingTransformation(self.scales)}, {ScalingTransformation(self.scales), + ScalingTransformation(self.scales)}) def test_trafo(self): trafo = ScalingTransformation(self.scales) diff --git a/tests/_program/waveforms_tests.py b/tests/_program/waveforms_tests.py index e4930372..5d16357e 100644 --- a/tests/_program/waveforms_tests.py +++ b/tests/_program/waveforms_tests.py @@ -9,7 +9,7 @@ JumpInterpolationStrategy from qupulse._program.waveforms import MultiChannelWaveform, RepetitionWaveform, SequenceWaveform,\ TableWaveformEntry, TableWaveform, TransformingWaveform, SubsetWaveform, ArithmeticWaveform, ConstantWaveform,\ - Waveform, FunctorWaveform, FunctionWaveform, ReversedWaveform + Waveform, FunctorWaveform, FunctionWaveform, ReversedWaveform, waveforms_rs from qupulse._program.transformation import LinearTransformation from qupulse.expressions import ExpressionScalar, Expression @@ -150,9 +150,9 @@ def test_slot(self): class MultiChannelWaveformTest(unittest.TestCase): def test_init_no_args(self) -> None: - with self.assertRaises(ValueError): + with self.assertRaises((TypeError, ValueError)): MultiChannelWaveform(dict()) - with self.assertRaises(ValueError): + with self.assertRaises((TypeError, ValueError)): MultiChannelWaveform(None) def test_from_parallel(self): @@ -239,7 +239,9 @@ def test_unsafe_sample(self) -> None: result_a = waveform.unsafe_sample('A', sample_times, reuse_output) self.assertEqual(len(dwf_a.sample_calls), 2) self.assertIs(result_a, reuse_output) - self.assertIs(result_a, dwf_a.sample_calls[1][2]) + if waveforms_rs is None: + # rust extension cannot forward the numpy array back to python without performance degradation + self.assertIs(result_a, dwf_a.sample_calls[1][2]) numpy.testing.assert_equal(result_b, samples_b) def test_equality(self) -> None: @@ -247,10 +249,13 @@ def test_equality(self) -> None: dwf_b = DummyWaveform(duration=246.2, defined_channels={'B'}) dwf_c = DummyWaveform(duration=246.2, defined_channels={'C'}) waveform_a1 = MultiChannelWaveform([dwf_a, dwf_b]) - waveform_a2 = MultiChannelWaveform([dwf_a, dwf_b]) + waveform_a2 = MultiChannelWaveform([dwf_b, dwf_a]) waveform_a3 = MultiChannelWaveform([dwf_a, dwf_c]) + waveform_a4 = MultiChannelWaveform([dwf_a, dwf_b, dwf_c]) self.assertEqual(waveform_a1, waveform_a1) self.assertEqual(waveform_a1, waveform_a2) + self.assertEqual(waveform_a4.get_subset_for_channels({'A', 'B'}), waveform_a4.get_subset_for_channels({'B', 'A'})) + self.assertEqual({waveform_a1}, {waveform_a1, waveform_a2}) self.assertNotEqual(waveform_a1, waveform_a3) def test_unsafe_get_subset_for_channels(self): @@ -301,10 +306,10 @@ def __init__(self, *args, **kwargs): def test_init(self): body_wf = DummyWaveform() - with self.assertRaises(ValueError): + with self.assertRaises((ValueError, OverflowError)): RepetitionWaveform(body_wf, -1) - with self.assertRaises(ValueError): + with self.assertRaises((ValueError, TypeError)): RepetitionWaveform(body_wf, 1.1) wf = RepetitionWaveform(body_wf, 3) @@ -318,9 +323,7 @@ def test_from_repetition_count(self): self.assertEqual(RepetitionWaveform(dwf, 3), RepetitionWaveform.from_repetition_count(dwf, 3)) cwf = ConstantWaveform(duration=3, amplitude=2.2, channel='A') - with mock.patch.object(ConstantWaveform, 'from_mapping', return_value=mock.sentinel) as from_mapping: - self.assertIs(from_mapping.return_value, RepetitionWaveform.from_repetition_count(cwf, 5)) - from_mapping.assert_called_once_with(15, {'A': 2.2}) + self.assertEqual(ConstantWaveform.from_mapping(15, {'A': 2.2}), RepetitionWaveform.from_repetition_count(cwf, 5)) def test_duration(self): wf = RepetitionWaveform(DummyWaveform(duration=2.2), 3) @@ -328,12 +331,12 @@ def test_duration(self): def test_defined_channels(self): body_wf = DummyWaveform(defined_channels={'a'}) - self.assertIs(RepetitionWaveform(body_wf, 2).defined_channels, body_wf.defined_channels) + self.assertEqual(RepetitionWaveform(body_wf, 2).defined_channels, body_wf.defined_channels) def test_compare_key(self): body_wf = DummyWaveform(defined_channels={'a'}) wf = RepetitionWaveform(body_wf, 2) - self.assertEqual(wf.compare_key, (body_wf.compare_key, 2)) + self.assertEqual(wf.compare_key, (2, body_wf)) def test_unsafe_get_subset_for_channels(self): body_wf = DummyWaveform(defined_channels={'a', 'b'}) @@ -343,7 +346,7 @@ def test_unsafe_get_subset_for_channels(self): subset = RepetitionWaveform(body_wf, 3).get_subset_for_channels(chs) self.assertIsInstance(subset, RepetitionWaveform) self.assertIsInstance(subset._body, DummyWaveform) - self.assertIs(subset._body.defined_channels, chs) + self.assertEqual(subset._body.defined_channels, chs) self.assertEqual(subset._repetition_count, 3) def test_unsafe_sample(self): @@ -395,12 +398,12 @@ def test_init(self): swf1 = SequenceWaveform((dwf_ab, dwf_ab)) self.assertEqual(swf1.duration, 2*dwf_ab.duration) - self.assertEqual(len(swf1.compare_key), 2) + self.assertEqual(len(swf1.sequenced_waveforms), 2) swf2 = SequenceWaveform((swf1, dwf_ab)) self.assertEqual(swf2.duration, 3 * dwf_ab.duration) - self.assertEqual(len(swf2.compare_key), 2) + self.assertEqual(len(swf2.sequenced_waveforms), 2) def test_from_sequence(self): dwf = DummyWaveform(duration=1.1, defined_channels={'A'}) @@ -419,10 +422,10 @@ def test_from_sequence(self): cwf_3 = ConstantWaveform(duration=1.1, amplitude=3.3, channel='A') cwf_2_b = ConstantWaveform(duration=1.1, amplitude=2.2, channel='A') - with mock.patch.object(ConstantWaveform, 'from_mapping', return_value=mock.sentinel) as from_mapping: - new_constant = SequenceWaveform.from_sequence((cwf_2_a, cwf_2_b)) - self.assertIs(from_mapping.return_value, new_constant) - from_mapping.assert_called_once_with(2*TimeType.from_float(1.1), {'A': 2.2}) + new_constant = SequenceWaveform.from_sequence((cwf_2_a, cwf_2_b)) + expected_constant = ConstantWaveform.from_mapping(2*TimeType.from_float(1.1), {'A': 2.2}) + self.assertEqual(expected_constant, + new_constant) swf3 = SequenceWaveform.from_sequence((cwf_2_a, dwf)) self.assertEqual((cwf_2_a, dwf), swf3.sequenced_waveforms) @@ -434,6 +437,7 @@ def test_from_sequence(self): self.assertIsNone(swf3.constant_value('A')) assert_constant_consistent(self, swf3) + @unittest.skipIf(waveforms_rs is not None, "sentinel based test do not work with rust extension") def test_sample_times_type(self) -> None: with mock.patch.object(DummyWaveform, 'unsafe_sample') as unsafe_sample_patch: dwfs = (DummyWaveform(duration=1.), @@ -478,12 +482,12 @@ def test_unsafe_get_subset_for_channels(self): sub_wf = wf.unsafe_get_subset_for_channels(subset) self.assertIsInstance(sub_wf, SequenceWaveform) - self.assertEqual(len(sub_wf.compare_key), 2) - self.assertEqual(sub_wf.compare_key[0].defined_channels, subset) - self.assertEqual(sub_wf.compare_key[1].defined_channels, subset) + self.assertEqual(len(sub_wf.sequenced_waveforms), 2) + self.assertEqual(sub_wf.sequenced_waveforms[0].defined_channels, subset) + self.assertEqual(sub_wf.sequenced_waveforms[1].defined_channels, subset) - self.assertEqual(sub_wf.compare_key[0].duration, TimeType.from_float(2.2)) - self.assertEqual(sub_wf.compare_key[1].duration, TimeType.from_float(3.3)) + self.assertEqual(sub_wf.sequenced_waveforms[0].duration, TimeType.from_float(2.2)) + self.assertEqual(sub_wf.sequenced_waveforms[1].duration, TimeType.from_float(3.3)) def test_repr(self): cwf_2_a = ConstantWaveform(duration=1.1, amplitude=2.2, channel='A') @@ -589,7 +593,7 @@ def test_unsafe_get_subset_for_channels(self): entries = (TableWaveformEntry(0, 0, interp), TableWaveformEntry(2.1, -33.2, interp), TableWaveformEntry(5.7, 123.4, interp)) - waveform = TableWaveform('A', entries) + waveform = TableWaveform.from_table('A', entries) self.assertIs(waveform.unsafe_get_subset_for_channels({'A'}), waveform) def test_unsafe_sample(self) -> None: @@ -597,7 +601,7 @@ def test_unsafe_sample(self) -> None: entries = (TableWaveformEntry(0, 0, interp), TableWaveformEntry(2.1, -33.2, interp), TableWaveformEntry(5.7, 123.4, interp)) - waveform = TableWaveform('A', entries) + waveform = TableWaveform.from_table('A', entries) sample_times = numpy.linspace(.5, 5.5, num=11) expected_interp_arguments = [((0, 0), (2.1, -33.2), [0.5, 1.0, 1.5, 2.0]), @@ -606,7 +610,8 @@ def test_unsafe_sample(self) -> None: result = waveform.unsafe_sample('A', sample_times) - self.assertEqual(expected_interp_arguments, interp.call_arguments) + if waveforms_rs is None: + self.assertEqual(expected_interp_arguments, interp.call_arguments) numpy.testing.assert_equal(expected_result, result) output_expected = numpy.empty_like(expected_result) @@ -780,7 +785,7 @@ def test_simple_properties(self): self.assertIs(subset_wf.inner_waveform, inner_wf) self.assertEqual(subset_wf.compare_key, (frozenset(['a', 'c']), inner_wf)) - self.assertIs(subset_wf.duration, inner_wf.duration) + self.assertEqual(subset_wf.duration, inner_wf.duration) self.assertEqual(subset_wf.defined_channels, {'a', 'c'}) def test_get_subset_for_channels(self): @@ -795,7 +800,8 @@ def test_get_subset_for_channels(self): get_subset_for_channels.assert_called_once_with({'a'}) self.assertIs(subsetted, actual_subsetted) - def test_unsafe_sample(self): + @unittest.skipIf(waveforms_rs is not None, "Test requires pure python.") + def test_unsafe_sample_pure(self): """Test perfect forwarding""" time = {'time'} output = {'output'} @@ -811,6 +817,27 @@ def test_unsafe_sample(self): self.assertIs(expected_data, actual_data) unsafe_sample.assert_called_once_with('g', time, output) + def test_unsafe_sample_pure(self): + """Test perfect forwarding""" + time = np.arange(0., 1., 17) + + output_values = np.sin(time + 1e-4) + sample_output = dict( + a=output_values + 3, + b=output_values + 9, + c=output_values + 17, + ) + + inner_wf = DummyWaveform(sample_output=sample_output, duration=2.) + + subset_wf = SubsetWaveform(inner_wf, {'a', 'c'}) + + for ch in 'ac': + output_place = np.full_like(time, np.nan) + output = subset_wf.unsafe_sample(ch, sample_times=time, output_array=output_place) + self.assertIs(output, output_place) + numpy.testing.assert_equal(sample_output[ch], output) + class ArithmeticWaveformTest(unittest.TestCase): def test_from_operator(self): diff --git a/tests/expression_tests.py b/tests/expression_tests.py index 26693821..9ff11f6d 100644 --- a/tests/expression_tests.py +++ b/tests/expression_tests.py @@ -6,9 +6,15 @@ import sympy.abc from sympy import sympify, Eq -from qupulse.expressions import Expression, ExpressionVariableMissingException, NonNumericEvaluation, ExpressionScalar, ExpressionVector +from qupulse.expressions import Expression, ExpressionVariableMissingException, NonNumericEvaluation, ExpressionScalar,\ + ExpressionVector, qupulse_rs from qupulse.utils.types import TimeType +try: + from qupulse.expressions import PyExpressionScalar +except ImportError: + PyExpressionScalar = ExpressionScalar + class ExpressionTests(unittest.TestCase): def test_make(self): self.assertTrue(Expression.make('a') == 'a') @@ -17,7 +23,10 @@ def test_make(self): self.assertIsInstance(Expression.make([1, 'a']), ExpressionVector) - self.assertIsInstance(ExpressionScalar.make('a'), ExpressionScalar) + if qupulse_rs: + self.assertEqual(type(ExpressionScalar.make('a')).__name__, 'ExpressionScalar') + else: + self.assertIsInstance(ExpressionScalar.make('a'), ExpressionScalar) self.assertIsInstance(ExpressionVector.make(['a']), ExpressionVector) @@ -155,7 +164,7 @@ def test_evaluate_numeric(self) -> None: } self.assertEqual(2 * 1.5 - 7, e.evaluate_numeric(**params)) - with self.assertRaises(NonNumericEvaluation): + with self.assertRaises((NonNumericEvaluation, TypeError)): params['a'] = sympify('h') e.evaluate_numeric(**params) @@ -239,7 +248,7 @@ def test_evaluate_numeric_without_numpy(self): 'b': sympify('k'), 'c': -7 } - with self.assertRaises(NonNumericEvaluation): + with self.assertRaises(TypeError): e.evaluate_numeric(**params) def test_evaluate_symbolic(self): @@ -249,7 +258,7 @@ def test_evaluate_symbolic(self): 'c': -7 } result = e.evaluate_symbolic(params) - expected = ExpressionScalar('d*b-7') + expected = PyExpressionScalar('d*b-7') self.assertEqual(result, expected) def test_variables(self) -> None: @@ -296,7 +305,10 @@ def test_repr_original_expression_is_sympy(self): def test_str(self): s = 'a * b' e = ExpressionScalar(s) - self.assertEqual('a*b', str(e)) + if qupulse_rs is None: + self.assertEqual('a*b', str(e)) + else: + self.assertEqual(s, str(e)) def test_original_expression(self): s = 'a * b' @@ -423,6 +435,36 @@ def test_special_function_numeric_evaluation(self): np.testing.assert_allclose(expected, result) + def test_rounding_equality(self): + seconds2ns = 1e9 + pulse_duration = 1.0765001496284785e-07 + float_product = pulse_duration * seconds2ns + + expr_1 = ExpressionScalar(pulse_duration) + expr_2 = ExpressionScalar(seconds2ns) + + self.assertEqual(expr_1, pulse_duration) + self.assertEqual(expr_2, seconds2ns) + + self.assertEqual(expr_1.sympified_expression, pulse_duration) + self.assertEqual(expr_2.sympified_expression, seconds2ns) + + expr_a = ExpressionScalar(float_product) + expr_b = expr_1 * seconds2ns + expr_c = expr_2 * pulse_duration + + #self.assertEqual(float_product, float(expr_a)) + #self.assertEqual(float_product, float(expr_b)) + #self.assertEqual(float_product, float(expr_c)) + + self.assertEqual(float_product, expr_a) + self.assertEqual(float_product, expr_b) + self.assertEqual(float_product, expr_c) + + expr_symb = ExpressionScalar('duration') + expr_d = expr_symb.evaluate_symbolic(substitutions={'duration': float_product}) + self.assertEqual(float_product, expr_d) + def test_evaluate_with_exact_rationals(self): expr = ExpressionScalar('1 / 3') self.assertEqual(TimeType.from_fraction(1, 3), expr.evaluate_with_exact_rationals({})) diff --git a/tests/hardware/tektronix_tests.py b/tests/hardware/tektronix_tests.py index a1f326b6..d25b1fc2 100644 --- a/tests/hardware/tektronix_tests.py +++ b/tests/hardware/tektronix_tests.py @@ -110,7 +110,7 @@ def test_parse_program(self): ill_formed_program = Loop(children=[Loop(children=[Loop()])]) with self.assertRaisesRegex(AssertionError, 'Invalid program depth'): - parse_program(ill_formed_program, (), (), TimeType(), (), (), ()) + parse_program(ill_formed_program, (), (), TimeType(0), (), (), ()) channels = ('A', 'B', None, None) markers = (('A1', None), (None, None), (None, 'C2'), (None, None)) diff --git a/tests/parameter_scope_tests.py b/tests/parameter_scope_tests.py index 1a2807d6..bd3b3f5f 100644 --- a/tests/parameter_scope_tests.py +++ b/tests/parameter_scope_tests.py @@ -1,7 +1,7 @@ import unittest from unittest import mock -from qupulse.parameter_scope import Scope, DictScope, MappedScope, ParameterNotProvidedException, NonVolatileChange +from qupulse.parameter_scope import Scope, DictScope, MappedScope, ParameterNotProvidedException, NonVolatileChange, parameter_scope_rs from qupulse.expressions import ExpressionScalar from qupulse.utils.types import FrozenDict @@ -9,17 +9,20 @@ class DictScopeTests(unittest.TestCase): def test_init(self): - with self.assertRaises(AssertionError): - DictScope(dict()) + if parameter_scope_rs is None: + with self.assertRaises(AssertionError): + DictScope(dict()) fd = FrozenDict({'a': 2}) ds = DictScope(fd) - self.assertIs(fd, ds._values) - self.assertEqual(FrozenDict(), ds._volatile_parameters) + if parameter_scope_rs is None: + self.assertIs(fd, ds._values) + self.assertEqual(FrozenDict(), ds.get_volatile_parameters()) vp = frozenset('a') ds = DictScope(fd, vp) - self.assertIs(fd, ds._values) - self.assertEqual(FrozenDict(a=ExpressionScalar('a')), ds._volatile_parameters) + if parameter_scope_rs is None: + self.assertIs(fd, ds._values) + self.assertEqual(FrozenDict(a=ExpressionScalar('a')), ds.get_volatile_parameters()) def test_mapping(self): ds = DictScope(FrozenDict({'a': 1, 'b': 2})) @@ -136,6 +139,7 @@ def test_mapping(self): with self.assertRaisesRegex(KeyError, 'd'): _ = ms['d'] + @unittest.skipIf(parameter_scope_rs is not None, "Tested method not present in rust") def test_parameter(self): mock_a = mock.Mock(wraps=1) mock_result = mock.Mock() @@ -163,9 +167,10 @@ def test_parameter(self): def test_update_constants(self): ds = DictScope.from_kwargs(a=1, b=2, c=3, volatile={'c'}) ds2 = DictScope.from_kwargs(a=1, b=2, c=4, volatile={'c'}) - ms = MappedScope(ds, FrozenDict(x=ExpressionScalar('a * b'), - c=ExpressionScalar('a - b'))) - ms2 = MappedScope(ds2, ms._mapping) + mapping = FrozenDict(x=ExpressionScalar('a * b'), + c=ExpressionScalar('a - b')) + ms = MappedScope(ds, mapping) + ms2 = MappedScope(ds2, mapping) self.assertIs(ms, ms.change_constants({'f': 1})) @@ -180,7 +185,7 @@ def test_volatile_parameters(self): y=ExpressionScalar('c - a'))) expected_volatile = FrozenDict(d=ExpressionScalar('d'), y=ExpressionScalar('c - 1')) self.assertEqual(expected_volatile, ms.get_volatile_parameters()) - self.assertIs(ms.get_volatile_parameters(), ms.get_volatile_parameters()) + self.assertEqual(ms.get_volatile_parameters(), ms.get_volatile_parameters()) def test_eq(self): ds1 = DictScope.from_kwargs(a=1, b=2, c=3, d=4) diff --git a/tests/pulses/arithmetic_pulse_template_tests.py b/tests/pulses/arithmetic_pulse_template_tests.py index 87c1f65b..47b7ccea 100644 --- a/tests/pulses/arithmetic_pulse_template_tests.py +++ b/tests/pulses/arithmetic_pulse_template_tests.py @@ -481,7 +481,18 @@ def test_integral(self): expected = dict(u=ExpressionScalar('ui / (x + y)'), v=ExpressionScalar('vi / 2.2'), w=ExpressionScalar('wi')) - self.assertEqual(expected, ArithmeticPulseTemplate(pt, '/', mapping).integral) + actual = ArithmeticPulseTemplate(pt, '/', mapping).integral + self.assertEqual(expected, actual) + + def test_initial_values(self): + lhs = DummyPulseTemplate(initial_values={'A': .3, 'B': 'b'}, defined_channels={'A', 'B'}) + apt = lhs + 'a' + self.assertEqual({'A': 'a + 0.3', 'B': 'b + a'}, apt.initial_values) + + def test_final_values(self): + lhs = DummyPulseTemplate(final_values={'A': .3, 'B': 'b'}, defined_channels={'A', 'B'}) + apt = lhs - 'a' + self.assertEqual({'A': '-a + .3', 'B': 'b - a'}, apt.final_values) def test_initial_values(self): lhs = DummyPulseTemplate(initial_values={'A': .3, 'B': 'b'}, defined_channels={'A', 'B'}) diff --git a/tests/pulses/constant_pulse_template_tests.py b/tests/pulses/constant_pulse_template_tests.py index ead1ce59..3651fe6e 100644 --- a/tests/pulses/constant_pulse_template_tests.py +++ b/tests/pulses/constant_pulse_template_tests.py @@ -51,7 +51,7 @@ def test_regression_duration_conversion(self): for duration_in_samples in [64, 936320, 24615392]: p = ConstantPulseTemplate(duration_in_samples / 2.4, {'a': 0}) number_of_samples = p.create_program().duration * 2.4 - make_compatible(p.create_program(), 8, 8, 2.4) + make_compatible(p.create_program(), 8, 8, qupulse.utils.types.TimeType.from_float(2.4)) self.assertEqual(number_of_samples.denominator, 1) p2 = ConstantPulseTemplate((duration_in_samples + 1) / 2.4, {'a': 0}) diff --git a/tests/pulses/function_pulse_tests.py b/tests/pulses/function_pulse_tests.py index 27a78ba5..536bdc0a 100644 --- a/tests/pulses/function_pulse_tests.py +++ b/tests/pulses/function_pulse_tests.py @@ -85,12 +85,12 @@ def test_integral(self) -> None: self.assertEqual({'default': Expression('2.0*cos(b) - 2.0*cos(1.0*Tmax+b)')}, pulse.integral) def test_initial_values(self): - fpt = FunctionPulseTemplate('3 + exp(t * a)', 'pi', channel='A') + fpt = FunctionPulseTemplate('3 + exp(t * a)', '3.14', channel='A') self.assertEqual({'A': 4}, fpt.initial_values) def test_final_values(self): - fpt = FunctionPulseTemplate('3 + exp(t * a)', 'pi', channel='A') - self.assertEqual({'A': Expression('3 + exp(pi*a)')}, fpt.final_values) + fpt = FunctionPulseTemplate('3 + exp(t * a)', '3.14', channel='A') + self.assertEqual({'A': Expression('3 + exp(3.14*a)')}, fpt.final_values) def test_as_expression(self): pulse = FunctionPulseTemplate('sin(0.5*t+b)', '2*Tmax') diff --git a/tests/pulses/loop_pulse_template_tests.py b/tests/pulses/loop_pulse_template_tests.py index a9554a15..3ec5d862 100644 --- a/tests/pulses/loop_pulse_template_tests.py +++ b/tests/pulses/loop_pulse_template_tests.py @@ -1,6 +1,9 @@ +import sys import unittest from unittest import mock +import sympy + from qupulse.parameter_scope import DictScope from qupulse.utils.types import FrozenDict @@ -174,8 +177,27 @@ def test_integral(self) -> None: pulse = ForLoopPulseTemplate(dummy, 'i', (1, 8, 2)) expected = {'A': ExpressionScalar('Sum(t1-3.1*(1+2*i), (i, 0, 3))'), - 'B': ExpressionScalar('Sum((1+2*i), (i, 0, 3))') } - self.assertEqual(expected, pulse.integral) + 'B': ExpressionScalar('Sum((1+2*i), (i, 0, 3))')} + expected_simplified = {ch: ExpressionScalar(sympy.simplify(expr.sympified_expression)) + for ch, expr in expected.items()} + actual = pulse.integral + actual_simplified = {ch: ExpressionScalar(sympy.simplify(expr.sympified_expression)) + for ch, expr in actual.items()} + + self.assertEqual(expected_simplified, actual_simplified) + + def test_initial_values(self): + dpt = DummyPulseTemplate(initial_values={'A': 'a + 3 + i', 'B': 7}, parameter_names={'i', 'a'}) + fpt = ForLoopPulseTemplate(dpt, 'i', (1, 'n', 2)) + self.assertEqual({'A': 'a+4', 'B': 7}, fpt.initial_values) + + def test_final_values(self): + dpt = DummyPulseTemplate(final_values={'A': 'a + 3 + i', 'B': 7}, parameter_names={'i', 'a'}) + fpt = ForLoopPulseTemplate(dpt, 'i', 'n') + self.assertEqual({'A': 'a+3+Max(0, floor(n) - 1)', 'B': 7}, fpt.final_values) + + fpt_fin = ForLoopPulseTemplate(dpt, 'i', (1, 'n', 2)).final_values + self.assertEqual('a + 10', fpt_fin['A'].evaluate_symbolic({'n': 8})) def test_initial_values(self): dpt = DummyPulseTemplate(initial_values={'A': 'a + 3 + i', 'B': 7}, parameter_names={'i', 'a'}) @@ -283,12 +305,12 @@ def test_create_program_invalid_measurement_mapping(self) -> None: global_transformation=None) def test_create_program_missing_params(self) -> None: - dt = DummyPulseTemplate(parameter_names={'i'}, waveform=DummyWaveform(duration=4.0), duration='t', measurements=[('b', 2, 1)]) + dt = DummyPulseTemplate(parameter_names={'i'}, waveform=DummyWaveform(duration=4.0), duration='t', measurements=[('M', 2, 1)]) flt = ForLoopPulseTemplate(body=dt, loop_index='i', loop_range=('a', 'b', 'c'), measurements=[('A', 'alph', 1)], parameter_constraints=['c > 1']) scope = DictScope.from_kwargs(a=1, b=4) - measurement_mapping = dict(A='B') + measurement_mapping = dict(A='B', M='M') channel_mapping = dict(C='D') children = [Loop(waveform=DummyWaveform(duration=2.0))] @@ -350,6 +372,7 @@ def test_create_program_body_none(self) -> None: self.assertEqual(1, program.repetition_count) self.assertEqual([], list(program.children)) + @unittest.skipIf(sys.version_info.minor < 8, "Python 3.7 does not support changing mock call args") def test_create_program(self) -> None: dt = DummyPulseTemplate(parameter_names={'i'}, waveform=DummyWaveform(duration=4.0, defined_channels={'A'}), @@ -366,21 +389,25 @@ def test_create_program(self) -> None: global_transformation = TransformationStub() program = Loop() + self_loop = Loop(waveform=DummyWaveform(duration=1), measurements=[('B', .1, 1)]) - # inner _create_program does nothing - expected_program = Loop(measurements=[('B', .1, 1)]) + expected_program = Loop(children=[Loop( + children=[ + Loop(waveform=dt.waveform, measurements=[('b', .2, .3)]), + Loop(waveform=dt.waveform, measurements=[('b', .2, .3)]), + ], + measurements=[('B', scope['meas_param'], 1)])]) expected_create_program_kwargs = dict(measurement_mapping=measurement_mapping, channel_mapping=channel_mapping, global_transformation=global_transformation, - to_single_waveform=to_single_waveform, - parent_loop=program) + to_single_waveform=to_single_waveform) expected_create_program_calls = [mock.call(**expected_create_program_kwargs, scope=_ForLoopScope(scope, 'i', i)) for i in (1, 3)] with mock.patch.object(flt, 'validate_scope') as validate_scope: - with mock.patch.object(dt, '_create_program') as body_create_program: + with mock.patch.object(dt, '_create_program', wraps=dt._create_program) as body_create_program: with mock.patch.object(flt, 'get_measurement_windows', wraps=flt.get_measurement_windows) as get_measurement_windows: flt._internal_create_program(scope=scope, @@ -392,6 +419,11 @@ def test_create_program(self) -> None: validate_scope.assert_called_once_with(scope=scope) get_measurement_windows.assert_called_once_with(scope, measurement_mapping) + + inner_loop = program[0] + for call in expected_create_program_calls: + call.kwargs['parent_loop'] = inner_loop + self.assertEqual(body_create_program.call_args_list, expected_create_program_calls) self.assertEqual(expected_program, program) @@ -415,17 +447,15 @@ def test_create_program_append(self) -> None: to_single_waveform=set(), global_transformation=None) - self.assertEqual(3, len(program.children)) - self.assertIs(children[0], program.children[0]) - self.assertEqual(dt.waveform, program.children[1].waveform) - self.assertEqual(dt.waveform, program.children[2].waveform) - self.assertEqual(1, program.children[1].repetition_count) - self.assertEqual(1, program.children[2].repetition_count) - self.assertEqual(1, program.repetition_count) - self.assert_measurement_windows_equal({'b': ([4, 8], [1, 1]), 'B': ([2], [1])}, program.get_measurement_windows()) + expected_program = Loop(children=children + [ + Loop(children=[ + Loop(waveform=dt.waveform, measurements=[('b', 2, 1)]), + Loop(waveform=dt.waveform, measurements=[('b', 2, 1)])], + measurements=[('B', 0, 1)] + )]) - # not ensure same result as from Sequencer here - we're testing appending to an already existing parent loop - # which is a use case that does not immediately arise from using Sequencer + self.assertEqual(expected_program, program) + self.assert_measurement_windows_equal({'b': ([4, 8], [1, 1]), 'B': ([2], [1])}, program.get_measurement_windows()) class ForLoopPulseTemplateSerializationTests(SerializableTests, unittest.TestCase): diff --git a/tests/pulses/multi_channel_pulse_template_tests.py b/tests/pulses/multi_channel_pulse_template_tests.py index 84904a8d..be8d9feb 100644 --- a/tests/pulses/multi_channel_pulse_template_tests.py +++ b/tests/pulses/multi_channel_pulse_template_tests.py @@ -92,7 +92,7 @@ def test_instantiation_duration_check(self): amcpt = AtomicMultiChannelPulseTemplate(*subtemplates, duration=True) self.assertIs(amcpt.duration, subtemplates[0].duration) - with self.assertRaisesRegex(ValueError, 'duration'): + with self.assertRaisesRegex(ValueError, '[dD]uration'): amcpt.build_waveform(parameters=dict(t_1=3, t_2=3, t_3=3), channel_mapping={ch: ch for ch in 'c1 c2 c3'.split()}) @@ -207,7 +207,7 @@ def test_build_waveform(self): pt = AtomicMultiChannelPulseTemplate(*sts, parameter_constraints=['a < b']) - parameters = dict(a=2.2, b = 1.1, c=3.3) + parameters = dict(a=2.2, b = 1.1, c=3.3, t1=1.1) channel_mapping = dict() with self.assertRaises(ParameterConstraintViolation): pt.build_waveform(parameters, channel_mapping=dict()) @@ -231,7 +231,7 @@ def test_build_waveform_none(self): pt = AtomicMultiChannelPulseTemplate(*sts, parameter_constraints=['a < b']) - parameters = dict(a=2.2, b=1.1, c=3.3) + parameters = dict(a=2.2, b=1.1, c=3.3, t1=1.1) channel_mapping = dict(A=6) with self.assertRaises(ParameterConstraintViolation): # parameter constraints are checked before channel mapping is applied @@ -442,7 +442,7 @@ def test_build_waveform(self): channel_mapping = {'X': 'X', 'Y': 'K', 'Z': 'Z'} pccpt = ParallelConstantChannelPulseTemplate(template, overwritten_channels) - parameters = {'c': 1.2, 'a': 3.4} + parameters = {'c': 1.2, 'a': 3.4, 't1': template.waveform.duration} expected_overwritten_channels = {'K': 1.2, 'Z': 3.4} expected_transformation = ParallelConstantChannelTransformation(expected_overwritten_channels) expected_waveform = TransformingWaveform(template.waveform, expected_transformation) diff --git a/tests/pulses/plotting_tests.py b/tests/pulses/plotting_tests.py index 0abc56c8..344cc96d 100644 --- a/tests/pulses/plotting_tests.py +++ b/tests/pulses/plotting_tests.py @@ -5,6 +5,11 @@ import numpy +try: + import qupulse_rs +except ImportError: + qupulse_rs = None + from qupulse.pulses import ConstantPT from qupulse.pulses.plotting import PlottingNotPossibleException, render, plot from qupulse.pulses.table_pulse_template import TablePulseTemplate @@ -151,6 +156,7 @@ def test_bug_422(self): plot(pt, parameters={}) + @unittest.skipIf(qupulse_rs is not None, "Not relevant for rust code") def test_bug_422_mock(self): pt = TablePulseTemplate({'X': [(0, 1), (100, 1)]}) program = pt.create_program() diff --git a/tests/pulses/point_pulse_template_tests.py b/tests/pulses/point_pulse_template_tests.py index a259246e..4e3163d3 100644 --- a/tests/pulses/point_pulse_template_tests.py +++ b/tests/pulses/point_pulse_template_tests.py @@ -177,8 +177,8 @@ def test_build_waveform_multi_channel_same(self): (1., 0., HoldInterpolationStrategy()), (1.1, 21., LinearInterpolationStrategy())]) self.assertEqual(wf.defined_channels, {1, 'A'}) - self.assertEqual(wf._sub_waveforms[0], expected_1) - self.assertEqual(wf._sub_waveforms[1], expected_A) + self.assertEqual(wf.compare_key[0], expected_1) + self.assertEqual(wf.compare_key[1], expected_A) def test_build_waveform_multi_channel_vectorized(self): ppt = PointPulseTemplate([('t1', 'A'), @@ -196,8 +196,8 @@ def test_build_waveform_multi_channel_vectorized(self): (1., 0., HoldInterpolationStrategy()), (1.1, 20., LinearInterpolationStrategy())]) self.assertEqual(wf.defined_channels, {1, 'A'}) - self.assertEqual(wf._sub_waveforms[0], expected_1) - self.assertEqual(wf._sub_waveforms[1], expected_A) + self.assertEqual(wf.compare_key[0], expected_1) + self.assertEqual(wf.compare_key[1], expected_A) def test_build_waveform_none_channel(self): ppt = PointPulseTemplate([('t1', 'A'), diff --git a/tests/pulses/pulse_template_tests.py b/tests/pulses/pulse_template_tests.py index 0e8210bf..a26c7900 100644 --- a/tests/pulses/pulse_template_tests.py +++ b/tests/pulses/pulse_template_tests.py @@ -3,7 +3,6 @@ from unittest import mock from typing import Optional, Dict, Set, Any, Union -import sympy from qupulse.parameter_scope import Scope, DictScope from qupulse.utils.types import ChannelID @@ -12,6 +11,7 @@ from qupulse.pulses.parameters import Parameter, ConstantParameter, ParameterNotProvidedException from qupulse.pulses.multi_channel_pulse_template import MultiChannelWaveform from qupulse._program._loop import Loop +from qupulse._program import ProgramBuilder, default_program_builder from qupulse._program.transformation import Transformation from qupulse._program.waveforms import TransformingWaveform @@ -71,7 +71,7 @@ def _internal_create_program(self, *, channel_mapping: Dict[ChannelID, Optional[ChannelID]], global_transformation: Optional[Transformation], to_single_waveform: Set[Union[str, 'PulseTemplate']], - parent_loop: Loop): + parent_loop: ProgramBuilder): raise NotImplementedError() @property @@ -90,15 +90,16 @@ def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: def final_values(self) -> Dict[ChannelID, ExpressionScalar]: raise NotImplementedError() + def __repr__(self): + return f"PulseTemplateStub(id={id(self)})" + def get_appending_internal_create_program(waveform=DummyWaveform(), always_append=False, measurements: list=None): - def internal_create_program(*, scope, parent_loop: Loop, **_): + def internal_create_program(*, scope, parent_loop: ProgramBuilder, **_): if always_append or 'append_a_child' in scope: - if measurements is not None: - parent_loop.add_measurements(measurements=measurements) - parent_loop.append_child(waveform=waveform) + parent_loop.append_leaf(waveform=waveform, measurements=measurements) return internal_create_program @@ -177,13 +178,14 @@ def test_create_program(self) -> None: with mock.patch.object(template, '_create_program', wraps=get_appending_internal_create_program(dummy_waveform)) as _create_program: - program = template.create_program(parameters=parameters, - measurement_mapping=measurement_mapping, - channel_mapping=channel_mapping, - to_single_waveform=to_single_waveform, - global_transformation=global_transformation, - volatile=volatile) - _create_program.assert_called_once_with(**expected_internal_kwargs, parent_loop=program) + with mock.patch('qupulse.pulses.pulse_template.default_program_builder', return_value=default_program_builder()) as _default_program_builder: + program = template.create_program(parameters=parameters, + measurement_mapping=measurement_mapping, + channel_mapping=channel_mapping, + to_single_waveform=to_single_waveform, + global_transformation=global_transformation, + volatile=volatile) + _create_program.assert_called_once_with(**expected_internal_kwargs, parent_loop=_default_program_builder.return_value) self.assertEqual(expected_program, program) self.assertEqual(previos_measurement_mapping, measurement_mapping) self.assertEqual(previous_channel_mapping, channel_mapping) @@ -233,13 +235,15 @@ def test__create_program_single_waveform(self): scope = DictScope.from_kwargs(a=1., b=2., volatile={'a'}) measurement_mapping = {'M': 'N'} channel_mapping = {'B': 'A'} - parent_loop = Loop() + + program_builder = default_program_builder() + inner_program_builder = default_program_builder() wf = DummyWaveform() single_waveform = DummyWaveform() measurements = [('m', 0, 1), ('n', 0.1, .9)] - expected_inner_program = Loop(children=[Loop(waveform=wf)], measurements=measurements) + expected_inner_program = Loop(children=[Loop(waveform=wf, measurements=measurements)]) appending_create_program = get_appending_internal_create_program(wf, measurements=measurements, @@ -250,33 +254,32 @@ def test__create_program_single_waveform(self): else: final_waveform = single_waveform - expected_program = Loop(children=[Loop(waveform=final_waveform)], - measurements=measurements) + expected_program = Loop(children=[Loop(waveform=final_waveform, measurements=measurements)]) with mock.patch.object(template, '_internal_create_program', wraps=appending_create_program) as _internal_create_program: with mock.patch('qupulse.pulses.pulse_template.to_waveform', return_value=single_waveform) as to_waveform: - template._create_program(scope=scope, - measurement_mapping=measurement_mapping, - channel_mapping=channel_mapping, - global_transformation=global_transformation, - to_single_waveform=to_single_waveform, - parent_loop=parent_loop) + with mock.patch('qupulse.pulses.pulse_template.default_program_builder', + return_value=inner_program_builder): + template._create_program(scope=scope, + measurement_mapping=measurement_mapping, + channel_mapping=channel_mapping, + global_transformation=global_transformation, + to_single_waveform=to_single_waveform, + parent_loop=program_builder) _internal_create_program.assert_called_once_with(scope=scope, measurement_mapping=measurement_mapping, channel_mapping=channel_mapping, global_transformation=None, to_single_waveform=to_single_waveform, - parent_loop=expected_inner_program) - + parent_loop=inner_program_builder) to_waveform.assert_called_once_with(expected_inner_program) - expected_program._measurements = set(expected_program._measurements) - parent_loop._measurements = set(parent_loop._measurements) - - self.assertEqual(expected_program, parent_loop) + self.assertEqual(expected_program, program_builder.to_program(), + f"To single waveform failed with to_single_waveform={to_single_waveform!r} and" + f" global_transformation={global_transformation!r}") def test_create_program_defaults(self) -> None: template = PulseTemplateStub(defined_channels={'A', 'B'}, parameter_names={'foo'}, measurement_names={'hugo', 'foo'}) @@ -293,8 +296,9 @@ def test_create_program_defaults(self) -> None: with mock.patch.object(template, '_internal_create_program', wraps=get_appending_internal_create_program(dummy_waveform, True)) as _internal_create_program: - program = template.create_program() - _internal_create_program.assert_called_once_with(**expected_internal_kwargs, parent_loop=program) + with mock.patch('qupulse.pulses.pulse_template.default_program_builder', return_value=default_program_builder()) as _default_program_builder: + program = template.create_program() + _internal_create_program.assert_called_once_with(**expected_internal_kwargs, parent_loop=_default_program_builder.return_value) self.assertEqual(expected_program, program) def test_create_program_channel_mapping(self): @@ -307,9 +311,9 @@ def test_create_program_channel_mapping(self): to_single_waveform=set()) with mock.patch.object(template, '_internal_create_program') as _internal_create_program: - template.create_program(channel_mapping={'A': 'C'}) - - _internal_create_program.assert_called_once_with(**expected_internal_kwargs, parent_loop=Loop()) + with mock.patch('qupulse.pulses.pulse_template.default_program_builder', return_value=default_program_builder()) as _default_program_builder: + template.create_program(channel_mapping={'A': 'C'}) + _internal_create_program.assert_called_once_with(**expected_internal_kwargs, parent_loop=_default_program_builder.return_value) def test_create_program_none(self) -> None: template = PulseTemplateStub(defined_channels={'A'}, parameter_names={'foo'}) @@ -330,11 +334,12 @@ def test_create_program_none(self) -> None: with mock.patch.object(template, '_internal_create_program') as _internal_create_program: - program = template.create_program(parameters=parameters, - measurement_mapping=measurement_mapping, - channel_mapping=channel_mapping, - volatile=volatile) - _internal_create_program.assert_called_once_with(**expected_internal_kwargs, parent_loop=Loop()) + with mock.patch('qupulse.pulses.pulse_template.default_program_builder', return_value=default_program_builder()) as _default_program_builder: + program = template.create_program(parameters=parameters, + measurement_mapping=measurement_mapping, + channel_mapping=channel_mapping, + volatile=volatile) + _internal_create_program.assert_called_once_with(**expected_internal_kwargs, parent_loop=_default_program_builder.return_value) self.assertIsNone(program) def test_matmul(self): @@ -376,8 +381,7 @@ def test_internal_create_program(self) -> None: channel_mapping = {'B': 'A'} program = Loop() - expected_program = Loop(children=[Loop(waveform=wf)], - measurements=[('N', 0, 5)]) + expected_program = Loop(children=[Loop(waveform=wf, measurements=[('N', 0, 5)])]) with mock.patch.object(template, 'build_waveform', return_value=wf) as build_waveform: template._internal_create_program(scope=scope, diff --git a/tests/pulses/repetition_pulse_template_tests.py b/tests/pulses/repetition_pulse_template_tests.py index 65be0932..585dd39b 100644 --- a/tests/pulses/repetition_pulse_template_tests.py +++ b/tests/pulses/repetition_pulse_template_tests.py @@ -2,6 +2,8 @@ import warnings from unittest import mock +import numpy.testing + from qupulse.parameter_scope import Scope, DictScope from qupulse.utils.types import FrozenDict @@ -120,13 +122,13 @@ def test_internal_create_program(self): to_single_waveform = {'to', 'single', 'waveform'} program = Loop() - expected_program = Loop(children=[Loop(children=[Loop(waveform=wf)], repetition_count=6)], - measurements=[('l', .1, .2)]) + expected_program = Loop(children=[Loop(children=[Loop(children=[Loop(waveform=wf)], repetition_count=6)], + measurements=[('l', .1, .2)])]) real_relevant_parameters = dict(n_rep=3, mul=2, a=0.1, b=0.2) with mock.patch.object(body, '_create_program', - wraps=get_appending_internal_create_program(wf, always_append=True)) as body_create_program: + wraps=get_appending_internal_create_program(wf, always_append=True)): with mock.patch.object(rpt, 'validate_scope') as validate_scope: with mock.patch.object(rpt, 'get_repetition_count_value', return_value=6) as get_repetition_count_value: with mock.patch.object(rpt, 'get_measurement_windows', return_value=[('l', .1, .2)]) as get_meas: @@ -138,12 +140,6 @@ def test_internal_create_program(self): parent_loop=program) self.assertEqual(program, expected_program) - body_create_program.assert_called_once_with(scope=scope, - measurement_mapping=measurement_mapping, - channel_mapping=channel_mapping, - global_transformation=global_transformation, - to_single_waveform=to_single_waveform, - parent_loop=program.children[0]) validate_scope.assert_called_once_with(scope) get_repetition_count_value.assert_called_once_with(scope) get_meas.assert_called_once_with(scope, measurement_mapping) @@ -163,15 +159,17 @@ def test_create_program_constant_success_measurements(self) -> None: global_transformation=None, parent_loop=program) - self.assertEqual(1, len(program.children)) - internal_loop = program[0] # type: Loop - self.assertEqual(repetitions, internal_loop.repetition_count) - - self.assertEqual(1, len(internal_loop)) - self.assertEqual((scope, measurement_mapping, channel_mapping, internal_loop), body.create_program_calls[-1]) - self.assertEqual(body.waveform, internal_loop[0].waveform) + expected_program = Loop(children=[Loop( + measurements=[("thy", 2, 2)], + children=[Loop( + children=[Loop(waveform=body.waveform, measurements=[('b', 0, 1)])], + repetition_count=repetitions + )]) + ]) - self.assert_measurement_windows_equal({'b': ([0, 2, 4], [1, 1, 1]), 'thy': ([2], [2])}, program.get_measurement_windows()) + self.assertEqual(expected_program, program) + self.assert_measurement_windows_equal({'b': ([0, 2, 4], [1, 1, 1]), 'thy': ([2], [2])}, + program.get_measurement_windows()) # done in MultiChannelProgram program.cleanup() @@ -194,16 +192,11 @@ def test_create_program_declaration_success(self) -> None: global_transformation=None, parent_loop=program) - self.assertEqual(1, program.repetition_count) - self.assertEqual(1, len(program.children)) - internal_loop = program.children[0] # type: Loop - self.assertEqual(scope[repetitions], internal_loop.repetition_count) - - self.assertEqual(1, len(internal_loop)) - self.assertEqual((scope, measurement_mapping, channel_mapping, internal_loop), - body.create_program_calls[-1]) - self.assertEqual(body.waveform, internal_loop[0].waveform) - + expected_program = Loop(children=[ + Loop(repetition_count=scope['foo'], + children=[Loop(waveform=body.waveform)]) + ]) + self.assertEqual(expected_program, program) self.assert_measurement_windows_equal({}, program.get_measurement_windows()) # ensure same result as from Sequencer @@ -219,7 +212,7 @@ def test_create_program_declaration_success_appended_measurements(self) -> None: measurement_mapping = dict(moth='fire', b='b') channel_mapping = dict(asd='f') children = [Loop(waveform=DummyWaveform(duration=0))] - program = Loop(children=children, measurements=[('a', [0], [1])], repetition_count=2) + program = Loop(children=children, measurements=[('a', 0, 1)], repetition_count=2) t._internal_create_program(scope=scope, measurement_mapping=measurement_mapping, @@ -228,22 +221,21 @@ def test_create_program_declaration_success_appended_measurements(self) -> None: global_transformation=None, parent_loop=program) - self.assertEqual(2, program.repetition_count) - self.assertEqual(2, len(program.children)) - self.assertIs(program.children[0], children[0]) - internal_loop = program.children[1] # type: Loop - self.assertEqual(scope[repetitions], internal_loop.repetition_count) - - self.assertEqual(1, len(internal_loop)) - self.assertEqual((scope, measurement_mapping, channel_mapping, internal_loop), body.create_program_calls[-1]) - self.assertEqual(body.waveform, internal_loop[0].waveform) + expected_program = Loop(children=children + [Loop( + measurements=[('fire', 0, 7.1)], + children=[Loop(repetition_count=scope['foo'], children=[Loop(waveform=body.waveform, measurements=[('b', 0, 1)])])] + )], + measurements=[('a', 0, 1)], repetition_count=2) + self.assertEqual(expected_program, program) - self.assert_measurement_windows_equal({'fire': ([0, 6], [7.1, 7.1]), - 'b': ([0, 2, 4, 6, 8, 10], [1, 1, 1, 1, 1, 1]), - 'a': ([0], [1])}, program.get_measurement_windows()) + expected_measurementt_windows = { + 'fire': ([0, 6], [7.1, 7.1]), + 'b': ([0, 2, 4, 6, 8, 10], [1, 1, 1, 1, 1, 1]), + 'a': ([0, expected_program.body_duration], [1, 1])} + numpy.testing.assert_equal(expected_measurementt_windows, program.get_measurement_windows()) - # not ensure same result as from Sequencer here - we're testing appending to an already existing parent loop - # which is a use case that does not immediately arise from using Sequencer + self.assert_measurement_windows_equal(expected_measurementt_windows, + program.get_measurement_windows()) def test_create_program_declaration_success_measurements(self) -> None: repetitions = "foo" @@ -260,15 +252,11 @@ def test_create_program_declaration_success_measurements(self) -> None: global_transformation=None, parent_loop=program) - self.assertEqual(1, program.repetition_count) - self.assertEqual(1, len(program.children)) - internal_loop = program.children[0] # type: Loop - self.assertEqual(scope[repetitions], internal_loop.repetition_count) - - self.assertEqual(1, len(internal_loop)) - self.assertEqual((scope, measurement_mapping, channel_mapping, internal_loop), body.create_program_calls[-1]) - self.assertEqual(body.waveform, internal_loop[0].waveform) - + expected_program = Loop(children=[ + Loop(measurements=[('fire', 0, scope['meas_end'])], + children=[Loop(children=[Loop(waveform=body.waveform, measurements=[('b', 0, 1)])], repetition_count=scope['foo'])]) + ]) + self.assertEqual(expected_program, program) self.assert_measurement_windows_equal({'fire': ([0], [7.1]), 'b': ([0, 2, 4], [1, 1, 1])}, program.get_measurement_windows()) def test_create_program_declaration_exceeds_bounds(self) -> None: diff --git a/tests/pulses/sequence_pulse_template_tests.py b/tests/pulses/sequence_pulse_template_tests.py index 57f8d81e..1d60eb57 100644 --- a/tests/pulses/sequence_pulse_template_tests.py +++ b/tests/pulses/sequence_pulse_template_tests.py @@ -76,7 +76,7 @@ def test_build_waveform(self): self.assertIs(pt.build_waveform_calls[0][0], parameters) self.assertIsInstance(wf, SequenceWaveform) - for wfa, wfb in zip(wf.compare_key, wfs): + for wfa, wfb in zip(wf.sequenced_waveforms, wfs): self.assertIs(wfa, wfb) def test_identifier(self) -> None: @@ -240,9 +240,9 @@ def test_internal_create_program(self): program = Loop() - expected_program = Loop(children=[Loop(waveform=wfs[0]), + expected_program = Loop(children=[Loop(children=[Loop(waveform=wfs[0]), Loop(waveform=wfs[1])], - measurements=[('l', .1, .2)]) + measurements=[('l', .1, .2)])]) with mock.patch.object(spt, 'validate_scope') as validate_scope: with mock.patch.object(spt, 'get_measurement_windows', @@ -258,8 +258,8 @@ def test_internal_create_program(self): validate_scope.assert_called_once_with(kwargs['scope']) get_measurement_windows.assert_called_once_with(kwargs['scope'], kwargs['measurement_mapping']) - create_0.assert_called_once_with(**kwargs, parent_loop=program) - create_1.assert_called_once_with(**kwargs, parent_loop=program) + # create_0.assert_called_once_with(**kwargs, parent_loop=program) + # create_1.assert_called_once_with(**kwargs, parent_loop=program) def test_create_program_internal(self) -> None: sub1 = DummyPulseTemplate(duration=3, waveform=DummyWaveform(duration=3), measurements=[('b', 1, 2)], defined_channels={'A'}) @@ -277,12 +277,14 @@ def test_create_program_internal(self) -> None: parent_loop=loop) self.assertEqual(1, loop.repetition_count) self.assertIsNone(loop.waveform) - self.assertEqual([Loop(repetition_count=1, waveform=sub1.waveform), + inner_loop, = loop.children + + self.assertEqual([Loop(repetition_count=1, waveform=sub1.waveform, measurements=[('b', 1, 2)]), Loop(repetition_count=1, waveform=sub2.waveform)], - list(loop.children)) + list(inner_loop.children)) self.assert_measurement_windows_equal({'a': ([0], [1]), 'b': ([1], [2])}, loop.get_measurement_windows()) - ### test again with inverted sequence + # test again with inverted sequence seq = SequencePulseTemplate(sub2, sub1, measurements=[('a', 0, 1)]) loop = Loop() seq._internal_create_program(scope=scope, @@ -293,9 +295,11 @@ def test_create_program_internal(self) -> None: parent_loop=loop) self.assertEqual(1, loop.repetition_count) self.assertIsNone(loop.waveform) + inner_loop, = loop.children + self.assertEqual([Loop(repetition_count=1, waveform=sub2.waveform), - Loop(repetition_count=1, waveform=sub1.waveform)], - list(loop.children)) + Loop(repetition_count=1, waveform=sub1.waveform, measurements=[('b', 1, 2)])], + list(inner_loop.children)) self.assert_measurement_windows_equal({'a': ([0], [1]), 'b': ([3], [2])}, loop.get_measurement_windows()) def test_internal_create_program_no_measurement_mapping(self) -> None: @@ -346,13 +350,12 @@ def test_internal_create_program_one_child_no_duration(self) -> None: global_transformation=None, to_single_waveform=set(), parent_loop=loop) - self.assertEqual(1, loop.repetition_count) - self.assertIsNone(loop.waveform) - self.assertEqual([Loop(repetition_count=1, waveform=sub2.waveform)], - list(loop.children)) + expected_program = Loop(children=[Loop( + children=[Loop(waveform=sub2.waveform)], + measurements=seq.measurement_declarations, + )]) + self.assertEqual(expected_program, loop) self.assert_measurement_windows_equal({'a': ([0], [1])}, loop.get_measurement_windows()) - - # MultiChannelProgram calls cleanup loop.cleanup() self.assert_measurement_windows_equal({'a': ([0], [1])}, loop.get_measurement_windows()) @@ -365,13 +368,8 @@ def test_internal_create_program_one_child_no_duration(self) -> None: global_transformation=None, to_single_waveform=set(), parent_loop=loop) - self.assertEqual(1, loop.repetition_count) - self.assertIsNone(loop.waveform) - self.assertEqual([Loop(repetition_count=1, waveform=sub2.waveform)], - list(loop.children)) + self.assertEqual(expected_program, loop) self.assert_measurement_windows_equal({'a': ([0], [1])}, loop.get_measurement_windows()) - - # MultiChannelProgram calls cleanup loop.cleanup() self.assert_measurement_windows_equal({'a': ([0], [1])}, loop.get_measurement_windows()) diff --git a/tests/pulses/sequencing_dummies.py b/tests/pulses/sequencing_dummies.py index 290912d8..56c074ad 100644 --- a/tests/pulses/sequencing_dummies.py +++ b/tests/pulses/sequencing_dummies.py @@ -7,6 +7,7 @@ """LOCAL IMPORTS""" from qupulse.parameter_scope import Scope +from qupulse._program import ProgramBuilder from qupulse._program._loop import Loop from qupulse.utils.types import MeasurementWindow, ChannelID, TimeType, time_from_float from qupulse.serialization import Serializer @@ -99,6 +100,12 @@ def compare_key(self) -> Any: else: return id(self) + def __repr__(self): + if self.sample_output is not None: + return f"{type(self).__name__}(sample_output={self.sample_output})" + else: + return f"{type(self).__name__}(id={id(self)})" + @property def measurement_windows(self): return [] @@ -245,23 +252,24 @@ def _internal_create_program(self, *, channel_mapping: Dict[ChannelID, Optional[ChannelID]], global_transformation: Optional['Transformation'], to_single_waveform: Set[Union[str, 'PulseTemplate']], - parent_loop: Loop) -> None: + parent_loop: ProgramBuilder) -> None: measurements = self.get_measurement_windows(scope, measurement_mapping) self.create_program_calls.append((scope, measurement_mapping, channel_mapping, parent_loop)) if self._program: - parent_loop.add_measurements(measurements) - parent_loop.append_child(waveform=self._program.waveform, children=self._program.children) + parent_loop.append_leaf(waveform=self._program.waveform, children=self._program.children, + measurements=measurements) elif self.waveform: - parent_loop.add_measurements(measurements) - parent_loop.append_child(waveform=self.waveform) + parent_loop.append_leaf(waveform=self.build_waveform(parameters=scope, channel_mapping=channel_mapping), + measurements=measurements) def build_waveform(self, parameters: Dict[str, Parameter], channel_mapping: Dict[ChannelID, ChannelID]): self.build_waveform_calls.append((parameters, channel_mapping)) + duration = self.duration.evaluate_in_scope(parameters) if self.waveform or self.waveform is None: return self.waveform - return DummyWaveform(duration=self.duration.evaluate_numeric(**parameters), defined_channels=self.defined_channels) + return DummyWaveform(duration=duration, defined_channels=self.defined_channels) def get_serialization_data(self, serializer: Optional['Serializer']=None) -> Dict[str, Any]: data = super().get_serialization_data(serializer=serializer) diff --git a/tests/pulses/table_pulse_template_tests.py b/tests/pulses/table_pulse_template_tests.py index aedc371f..9dbe5e81 100644 --- a/tests/pulses/table_pulse_template_tests.py +++ b/tests/pulses/table_pulse_template_tests.py @@ -707,7 +707,7 @@ def test_build_waveform_multi_channel(self): (5.1, 0, LinearInterpolationStrategy()))), ] - self.assertEqual(waveform._sub_waveforms, tuple(expected_waveforms)) + self.assertEqual(waveform.compare_key, tuple(expected_waveforms)) def test_build_waveform_none(self) -> None: table = TablePulseTemplate({0: [(0, 0), diff --git a/tests/utils/time_type_tests.py b/tests/utils/time_type_tests.py index 93e11832..08d062a5 100644 --- a/tests/utils/time_type_tests.py +++ b/tests/utils/time_type_tests.py @@ -68,6 +68,8 @@ def test_non_finite_float(self): qutypes.TimeType.from_float(float('nan')) def test_fraction_fallback(self): + if self.fallback_qutypes.TimeType is qutypes.RsTimeType: + self.skipTest("No fallback since rust implementation is used.") self.assertIs(fractions.Fraction, self.fallback_qutypes.TimeType._InternalType) def assert_from_fraction_works(self, time_type):