diff --git a/qupulse/_program/_loop.py b/qupulse/_program/_loop.py index e0a4e6c98..9b4aecdca 100644 --- a/qupulse/_program/_loop.py +++ b/qupulse/_program/_loop.py @@ -16,7 +16,7 @@ from qupulse._program.waveforms import SequenceWaveform, RepetitionWaveform -__all__ = ['Loop', 'MultiChannelProgram', 'make_compatible'] +__all__ = ['Loop', 'MultiChannelProgram', 'make_compatible', 'to_waveform'] class Loop(Node): diff --git a/qupulse/hardware/awgs/virtual.py b/qupulse/hardware/awgs/virtual.py new file mode 100644 index 000000000..a96da4a39 --- /dev/null +++ b/qupulse/hardware/awgs/virtual.py @@ -0,0 +1,102 @@ +"""This module contains the tools to setup a virtual AWG i.e. an AWG that forwards the program into a given callback. +This is handy to setup a simulation to test qupulse pulses.""" +from typing import Tuple, Optional, Callable, Set + +import numpy as np + +from qupulse.utils.types import ChannelID, TimeType +from qupulse._program._loop import Loop, make_compatible, to_waveform +from qupulse.hardware.awgs.base import AWG + + +__all__ = ['VirtualAWG'] + + +SamplingCallback = Callable[[np.ndarray], np.ndarray] +SamplingCallback.__doc__ = """Maps an array ov times to an array of voltages. The time array has to be ordered""" + + +def _create_sampling_callbacks(program: Loop, channels, voltage_transformations) -> Tuple[float, + Tuple[SamplingCallback, ...]]: + waveform = to_waveform(program) + + duration = float(waveform.duration) + + def get_callback(channel: Optional[ChannelID], voltage_transformation): + if channel is None: + return None + else: + def sample_channel(time: np.ndarray): + return voltage_transformation(waveform.get_sampled(channel, time)) + + return sample_channel + + callbacks = [get_callback(channel, voltage_transformation) + for channel, voltage_transformation in zip(channels, voltage_transformations)] + return duration, tuple(callbacks) + + +class VirtualAWG(AWG): + """This class allows registering callbacks the given program is fed into. + + TODO: + - adaptive sample rate (requires program analysis)""" + + def __init__(self, identifier: str, channels: int): + super().__init__(identifier) + + self._programs = {} + self._current_program = None + self._channels = tuple(range(channels)) + + self._function_handle_callback = None + + @property + def num_channels(self) -> int: + return len(self._channels) + + @property + def num_markers(self) -> int: + return 0 + + def upload(self, name: str, + program: Loop, + channels: Tuple[Optional[ChannelID], ...], + markers: Tuple[Optional[ChannelID], ...], + voltage_transformation: Tuple[Optional[Callable], ...], + force: bool=False): + if name in self._programs and not force: + raise RuntimeError('Program already known') + + self._programs[name] = (program, channels, voltage_transformation) + + def remove(self, name: str): + self._programs.pop(name) + + def clear(self): + self._programs.clear() + self._current_program = None + + def arm(self, name: Optional[str]): + self._current_program = name + + @property + def programs(self) -> Set[str]: + return set(self._programs.keys()) + + @property + def sample_rate(self) -> float: + return float('nan') + + def set_function_handle_callback(self, + callback: Optional[Callable[[float, Tuple[SamplingCallback, ...]], None]]): + """When run current program is called the given callback is called with the first positional argument being the + duration and following arguments being sampling callbacks as defined above.""" + self._function_handle_callback = callback + + def run_current_program(self): + (program, channels, voltage_transformations) = self._programs[self._current_program] + + if self._function_handle_callback is not None: + duration, sample_callbacks = _create_sampling_callbacks(program, channels, voltage_transformations) + self._function_handle_callback(duration, sample_callbacks) diff --git a/qupulse/utils/tree.py b/qupulse/utils/tree.py index 58be158f5..4e50191f1 100644 --- a/qupulse/utils/tree.py +++ b/qupulse/utils/tree.py @@ -79,6 +79,9 @@ def __getitem__(self: _NodeType, *args, **kwargs) ->Union[_NodeType, List[_NodeT def __len__(self) -> int: return len(self.__children) + def __reversed__(self): + return reversed(self.__children) + def get_depth_first_iterator(self: _NodeType) -> Generator[_NodeType, None, None]: stack = [(self, self.__children)] diff --git a/tests/hardware/virtual_tests.py b/tests/hardware/virtual_tests.py new file mode 100644 index 000000000..5115229ba --- /dev/null +++ b/tests/hardware/virtual_tests.py @@ -0,0 +1,68 @@ +import unittest +from unittest import mock +import math + +import numpy as np + +from qupulse.hardware.awgs.virtual import VirtualAWG +from qupulse._program._loop import Loop + +from tests.pulses.sequencing_dummies import DummyWaveform + + +class VirtualAWGTests(unittest.TestCase): + + def test_init(self): + vawg = VirtualAWG('asd', 5) + + self.assertEqual(vawg.identifier, 'asd') + self.assertEqual(vawg.num_channels, 5) + + def test_no_markers(self): + vawg = VirtualAWG('asd', 5) + + self.assertEqual(vawg.num_markers, 0) + + def test_sample_rate(self): + vawg = VirtualAWG('asd', 5) + + self.assertTrue(math.isnan(vawg.sample_rate)) + + def test_arm(self): + name = 'prognam' + vawg = VirtualAWG('asd', 5) + + vawg.arm(name) + + self.assertEqual(vawg._current_program, name) + + def test_function_handle_callback(self): + callback = mock.MagicMock() + + vawg = VirtualAWG('asd', 3) + + vts = (lambda x: x, lambda x: 2*x, None) + + dummy_program = Loop() + dummy_waveform = DummyWaveform(sample_output={'X': np.sin(np.arange(10)), + 'Y': np.cos(np.arange(10))}, duration=42, + defined_channels={'X', 'Y'}) + vawg.upload('test', dummy_program, ('X', 'Y', None), (), vts) + vawg.arm('test') + vawg.set_function_handle_callback(callback) + with mock.patch('qupulse.hardware.awgs.virtual.to_waveform', autospec=True, return_value=dummy_waveform) as dummy_to_waveform: + vawg.run_current_program() + + dummy_to_waveform.assert_called_once_with(dummy_program) + + callback.assert_called_once() + (duration, sample_callbacks), kwargs = callback.call_args + self.assertEqual(kwargs, {}) + + self.assertEqual(duration, dummy_waveform.duration) + x, y, n = sample_callbacks + self.assertIsNone(n) + + t = np.arange(10)*1. + np.testing.assert_equal(x(t), dummy_waveform.sample_output['X']) + np.testing.assert_equal(y(t), 2*dummy_waveform.sample_output['Y'])