Skip to content

Commit

Permalink
Merge pull request #866 from qutech/issues/865_time_reversal_program_…
Browse files Browse the repository at this point in the history
…builder

Add test and implementation for TimeReversalPulseTemplate ProgramBuilder support
  • Loading branch information
Nomos11 authored Jan 4, 2025
2 parents 6243493 + d7d8b33 commit 21469b7
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 9 deletions.
2 changes: 1 addition & 1 deletion qupulse/expressions/sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def _try_to_numeric(self) -> Optional[numbers.Number]:
return None
if isinstance(self._original_expression, ALLOWED_NUMERIC_SCALAR_TYPES):
return self._original_expression
expr = self._sympified_expression
expr = self._sympified_expression.doit()
if isinstance(expr, bool):
# sympify can return bool
return expr
Expand Down
82 changes: 80 additions & 2 deletions qupulse/program/linspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,19 @@ class LinSpaceNode:
def dependencies(self) -> Mapping[int, set]:
raise NotImplementedError

def reversed(self, offset: int, lengths: list):
"""Get the time reversed version of this linspace node. Since this is a non-local operation the arguments give
the context.
Args:
offset: Active iterations that are not reserved
lengths: Lengths of the currently active iterations that have to be reversed
Returns:
Time reversed version.
"""
raise NotImplementedError


@dataclass
class LinSpaceHold(LinSpaceNode):
Expand All @@ -60,13 +73,46 @@ def dependencies(self) -> Mapping[int, set]:
for idx, factors in enumerate(self.factors)
if factors}

def reversed(self, offset: int, lengths: list):
if not lengths:
return self
# If the iteration length is `n`, the starting point is shifted by `n - 1`
steps = [length - 1 for length in lengths]
bases = []
factors = []
for ch_base, ch_factors in zip(self.bases, self.factors):
if ch_factors is None or len(ch_factors) <= offset:
bases.append(ch_base)
factors.append(ch_factors)
else:
ch_reverse_base = ch_base + sum(step * factor
for factor, step in zip(ch_factors[offset:], steps))
reversed_factors = ch_factors[:offset] + tuple(-f for f in ch_factors[offset:])
bases.append(ch_reverse_base)
factors.append(reversed_factors)

if self.duration_factors is None or len(self.duration_factors) <= offset:
duration_factors = self.duration_factors
duration_base = self.duration_base
else:
duration_base = self.duration_base + sum((step * factor
for factor, step in zip(self.duration_factors[offset:], steps)), TimeType(0))
duration_factors = self.duration_factors[:offset] + tuple(-f for f in self.duration_factors[offset:])
return LinSpaceHold(tuple(bases), tuple(factors), duration_base=duration_base, duration_factors=duration_factors)


@dataclass
class LinSpaceArbitraryWaveform(LinSpaceNode):
"""This is just a wrapper to pipe arbitrary waveforms through the system."""
waveform: Waveform
channels: Tuple[ChannelID, ...]

def reversed(self, offset: int, lengths: list):
return LinSpaceArbitraryWaveform(
waveform=self.waveform.reversed(),
channels=self.channels,
)


@dataclass
class LinSpaceRepeat(LinSpaceNode):
Expand All @@ -81,6 +127,9 @@ def dependencies(self):
dependencies.setdefault(idx, set()).update(deps)
return dependencies

def reversed(self, offset: int, counts: list):
return LinSpaceRepeat(tuple(node.reversed(offset, counts) for node in reversed(self.body)), self.count)


@dataclass
class LinSpaceIter(LinSpaceNode):
Expand All @@ -100,6 +149,12 @@ def dependencies(self):
dependencies.setdefault(idx, set()).update(shortened)
return dependencies

def reversed(self, offset: int, lengths: list):
lengths.append(self.length)
reversed_iter = LinSpaceIter(tuple(node.reversed(offset, lengths) for node in reversed(self.body)), self.length)
lengths.pop()
return reversed_iter


class LinSpaceBuilder(ProgramBuilder):
"""This program builder supports efficient translation of pulse templates that use symbolic linearly
Expand Down Expand Up @@ -214,6 +269,14 @@ def with_iteration(self, index_name: str, rng: range,
if cmds:
self._stack[-1].append(LinSpaceIter(body=tuple(cmds), length=len(rng)))

@contextlib.contextmanager
def time_reversed(self) -> ContextManager['LinSpaceBuilder']:
self._stack.append([])
yield self
inner = self._stack.pop()
offset = len(self._ranges)
self._stack[-1].extend(node.reversed(offset, []) for node in reversed(inner))

def to_program(self) -> Optional[Sequence[LinSpaceNode]]:
if self._root():
return self._root()
Expand Down Expand Up @@ -414,8 +477,10 @@ def to_increment_commands(linspace_nodes: Sequence[LinSpaceNode]) -> List[Comman


class LinSpaceVM:
def __init__(self, channels: int):
def __init__(self, channels: int,
sample_resolution: TimeType = TimeType.from_fraction(1, 2)):
self.current_values = [np.nan] * channels
self.sample_resolution = sample_resolution
self.time = TimeType(0)
self.registers = tuple({} for _ in range(channels))

Expand All @@ -428,7 +493,20 @@ def __init__(self, channels: int):

def change_state(self, cmd: Union[Set, Increment, Wait, Play]):
if isinstance(cmd, Play):
raise NotImplementedError("TODO: Implement arbitrary waveform simulation")
dt = self.sample_resolution
t = TimeType(0)
total_duration = cmd.waveform.duration
while t <= total_duration and dt > 0:
sample_time = np.array([float(t)])
values = []
for (idx, ch) in enumerate(cmd.channels):
self.current_values[idx] = values.append(cmd.waveform.get_sampled(channel=ch, sample_times=sample_time)[0])
self.history.append(
(self.time, self.current_values.copy())
)
dt = min(total_duration - t, self.sample_resolution)
self.time += dt
t += dt
elif isinstance(cmd, Wait):
self.history.append(
(self.time, self.current_values.copy())
Expand Down
3 changes: 3 additions & 0 deletions qupulse/program/waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1277,3 +1277,6 @@ def compare_key(self) -> Hashable:

def reversed(self) -> 'Waveform':
return self._inner

def __repr__(self):
return f"ReversedWaveform(inner={self._inner!r})"
4 changes: 4 additions & 0 deletions tests/expressions/expression_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,10 @@ def test_special_function_numeric_evaluation(self):

np.testing.assert_allclose(expected, result)

def test_try_to_numeric(self):
expr = ExpressionScalar('Sum(9, (x, 0, 5), (y, 0, 7))')
self.assertEqual(expr._try_to_numeric(), 9*6*8)

def test_evaluate_with_exact_rationals(self):
expr = ExpressionScalar('1 / 3')
self.assertEqual(TimeType.from_fraction(1, 3), expr.evaluate_with_exact_rationals({}))
Expand Down
2 changes: 1 addition & 1 deletion tests/program/linspace_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def assert_vm_output_almost_equal(test: TestCase, expected, actual):
test.assertEqual(t_e, t_a, f"Differing times in {idx} element")
test.assertEqual(len(vals_e), len(vals_a), f"Differing channel count in {idx} element")
for ch, (val_e, val_a) in enumerate(zip(vals_e, vals_a)):
test.assertAlmostEqual(val_e, val_a, msg=f"Differing values in {idx} element channel {ch}")
test.assertAlmostEqual(val_e, val_a, msg=f"Differing values in {idx} of {len(expected)} element channel {ch}")


class SingleRampTest(TestCase):
Expand Down
41 changes: 36 additions & 5 deletions tests/pulses/time_reversal_pulse_template_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from qupulse.pulses.time_reversal_pulse_template import TimeReversalPulseTemplate
from qupulse.utils.types import TimeType
from qupulse.expressions import ExpressionScalar

from qupulse.program.loop import LoopBuilder
from qupulse.program.linspace import LinSpaceBuilder, LinSpaceVM, to_increment_commands
from tests.pulses.sequencing_dummies import DummyPulseTemplate
from tests.serialization_tests import SerializableTests

from tests.program.linspace_tests import assert_vm_output_almost_equal

class TimeReversalPulseTemplateTests(unittest.TestCase):
def test_simple_properties(self):
Expand All @@ -29,19 +30,49 @@ def test_simple_properties(self):

self.assertEqual(reversed_pt.identifier, 'reverse')

def test_time_reversal_program(self):
def test_time_reversal_loop(self):
inner = ConstantPT(4, {'a': 3}) @ FunctionPT('sin(t)', 5, channel='a')
manual_reverse = FunctionPT('sin(5 - t)', 5, channel='a') @ ConstantPT(4, {'a': 3})
time_reversed = TimeReversalPulseTemplate(inner)

program = time_reversed.create_program()
manual_program = manual_reverse.create_program()
program = time_reversed.create_program(program_builder=LoopBuilder())
manual_program = manual_reverse.create_program(program_builder=LoopBuilder())

t, data, _ = render(program, 9 / 10)
_, manual_data, _ = render(manual_program, 9 / 10)

np.testing.assert_allclose(data['a'], manual_data['a'])

def test_time_reversal_linspace(self):
constant_pt = ConstantPT(4, {'a': '3.0 + x * 1.0 + y * -0.3'})
function_pt = FunctionPT('sin(t)', 5, channel='a')
reversed_function_pt = function_pt.with_time_reversal()

inner = (constant_pt @ function_pt).with_iteration('x', 6)
inner_manual = (reversed_function_pt @ constant_pt).with_iteration('x', (5, -1, -1))

outer = inner.with_time_reversal().with_iteration('y', 8)
outer_man = inner_manual.with_iteration('y', 8)

self.assertEqual(outer.duration, outer_man.duration)

program = outer.create_program(program_builder=LinSpaceBuilder(channels=('a',)))
manual_program = outer_man.create_program(program_builder=LinSpaceBuilder(channels=('a',)))

commands = to_increment_commands(program)
manual_commands = to_increment_commands(manual_program)
self.assertEqual(commands, manual_commands)

manual_vm = LinSpaceVM(1)
manual_vm.set_commands(manual_commands)
manual_vm.run()

vm = LinSpaceVM(1)
vm.set_commands(commands)
vm.run()

assert_vm_output_almost_equal(self, manual_vm.history, vm.history)


class TimeReversalPulseTemplateSerializationTests(unittest.TestCase, SerializableTests):
@property
Expand Down

0 comments on commit 21469b7

Please sign in to comment.