Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] First draft for virtual AWG #426

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion qupulse/_program/_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from qupulse._program.waveforms import SequenceWaveform, RepetitionWaveform

__all__ = ['Loop', 'MultiChannelProgram', 'make_compatible']
__all__ = ['Loop', 'MultiChannelProgram', 'make_compatible', 'to_waveform']


class Loop(Node):
Expand Down
102 changes: 102 additions & 0 deletions qupulse/hardware/awgs/virtual.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""This module contains the tools to setup a virtual AWG i.e. an AWG that forwards the program into a given callback.
This is handy to setup a simulation to test qupulse pulses."""
from typing import Tuple, Optional, Callable, Set

import numpy as np

from qupulse.utils.types import ChannelID, TimeType
from qupulse._program._loop import Loop, make_compatible, to_waveform
from qupulse.hardware.awgs.base import AWG


__all__ = ['VirtualAWG']


SamplingCallback = Callable[[np.ndarray], np.ndarray]
SamplingCallback.__doc__ = """Maps an array ov times to an array of voltages. The time array has to be ordered"""


def _create_sampling_callbacks(program: Loop, channels, voltage_transformations) -> Tuple[float,
Tuple[SamplingCallback, ...]]:
waveform = to_waveform(program)

duration = float(waveform.duration)

def get_callback(channel: Optional[ChannelID], voltage_transformation):
if channel is None:
return None
else:
def sample_channel(time: np.ndarray):
return voltage_transformation(waveform.get_sampled(channel, time))

return sample_channel

callbacks = [get_callback(channel, voltage_transformation)
for channel, voltage_transformation in zip(channels, voltage_transformations)]
return duration, tuple(callbacks)


class VirtualAWG(AWG):
"""This class allows registering callbacks the given program is fed into.

TODO:
- adaptive sample rate (requires program analysis)"""

def __init__(self, identifier: str, channels: int):
super().__init__(identifier)

self._programs = {}
self._current_program = None
self._channels = tuple(range(channels))

self._function_handle_callback = None

@property
def num_channels(self) -> int:
return len(self._channels)

@property
def num_markers(self) -> int:
return 0

def upload(self, name: str,
program: Loop,
channels: Tuple[Optional[ChannelID], ...],
markers: Tuple[Optional[ChannelID], ...],
voltage_transformation: Tuple[Optional[Callable], ...],
force: bool=False):
if name in self._programs and not force:
raise RuntimeError('Program already known')

self._programs[name] = (program, channels, voltage_transformation)

def remove(self, name: str):
self._programs.pop(name)

def clear(self):
self._programs.clear()
self._current_program = None

def arm(self, name: Optional[str]):
self._current_program = name

@property
def programs(self) -> Set[str]:
return set(self._programs.keys())

@property
def sample_rate(self) -> float:
return float('nan')

def set_function_handle_callback(self,
callback: Optional[Callable[[float, Tuple[SamplingCallback, ...]], None]]):
"""When run current program is called the given callback is called with the first positional argument being the
duration and following arguments being sampling callbacks as defined above."""
self._function_handle_callback = callback

def run_current_program(self):
(program, channels, voltage_transformations) = self._programs[self._current_program]

if self._function_handle_callback is not None:
duration, sample_callbacks = _create_sampling_callbacks(program, channels, voltage_transformations)
self._function_handle_callback(duration, sample_callbacks)
3 changes: 3 additions & 0 deletions qupulse/utils/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ def __getitem__(self: _NodeType, *args, **kwargs) ->Union[_NodeType, List[_NodeT
def __len__(self) -> int:
return len(self.__children)

def __reversed__(self):
return reversed(self.__children)

def get_depth_first_iterator(self: _NodeType) -> Generator[_NodeType, None, None]:
stack = [(self, self.__children)]

Expand Down
68 changes: 68 additions & 0 deletions tests/hardware/virtual_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import unittest
from unittest import mock
import math

import numpy as np

from qupulse.hardware.awgs.virtual import VirtualAWG
from qupulse._program._loop import Loop

from tests.pulses.sequencing_dummies import DummyWaveform


class VirtualAWGTests(unittest.TestCase):

def test_init(self):
vawg = VirtualAWG('asd', 5)

self.assertEqual(vawg.identifier, 'asd')
self.assertEqual(vawg.num_channels, 5)

def test_no_markers(self):
vawg = VirtualAWG('asd', 5)

self.assertEqual(vawg.num_markers, 0)

def test_sample_rate(self):
vawg = VirtualAWG('asd', 5)

self.assertTrue(math.isnan(vawg.sample_rate))

def test_arm(self):
name = 'prognam'
vawg = VirtualAWG('asd', 5)

vawg.arm(name)

self.assertEqual(vawg._current_program, name)

def test_function_handle_callback(self):
callback = mock.MagicMock()

vawg = VirtualAWG('asd', 3)

vts = (lambda x: x, lambda x: 2*x, None)

dummy_program = Loop()
dummy_waveform = DummyWaveform(sample_output={'X': np.sin(np.arange(10)),
'Y': np.cos(np.arange(10))}, duration=42,
defined_channels={'X', 'Y'})
vawg.upload('test', dummy_program, ('X', 'Y', None), (), vts)
vawg.arm('test')
vawg.set_function_handle_callback(callback)
with mock.patch('qupulse.hardware.awgs.virtual.to_waveform', autospec=True, return_value=dummy_waveform) as dummy_to_waveform:
vawg.run_current_program()

dummy_to_waveform.assert_called_once_with(dummy_program)

callback.assert_called_once()
(duration, sample_callbacks), kwargs = callback.call_args
self.assertEqual(kwargs, {})

self.assertEqual(duration, dummy_waveform.duration)
x, y, n = sample_callbacks
self.assertIsNone(n)

t = np.arange(10)*1.
np.testing.assert_equal(x(t), dummy_waveform.sample_output['X'])
np.testing.assert_equal(y(t), 2*dummy_waveform.sample_output['Y'])