-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathExpSeq.m
631 lines (589 loc) · 25.6 KB
/
ExpSeq.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
%% Copyright (c) 2014-2018, Yichao Yu <[email protected]>
%
% This library is free software; you can redistribute it and/or
% modify it under the terms of the GNU Lesser General Public
% License as published by the Free Software Foundation; either
% version 3.0 of the License, or (at your option) any later version.
% This library is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
% Lesser General Public License for more details.
% You should have received a copy of the GNU Lesser General Public
% License along with this library.
classdef ExpSeq < ExpSeqBase
    %% `ExpSeq` is the object representing the entire experimental sequence (root node).
    % In addition to the properties and APIs for manipulating the tree
    % structure (timing APIs in `ExpSeqBase`), this holds the information and
    % APIs that apply to the whole sequence: global sequence settings (e.g.
    % default overrides, start and end callbacks, channel management), and the
    % APIs related to generating and running the sequence.
    properties
        %% Generation/driver related:
        % Map from driver name to driver instance
        drivers;
        % Map from driver name to the list of channel IDs managed by the driver
        driver_cids; % ::containers.Map
        % Drivers sorted in the order they should be run and waited.
        drivers_sorted;
        % Whether `generate` has completed.
        generated = false;
        % The total time of the sequence, cached by `generate`.
        % A negative value means the total time has not been cached yet.
        cached_total_time = -1;
        %% Channel management:
        % The channel name used when the channel is first added, indexed by channel ID.
        % Only used for plotting.
        orig_channel_names = {};
        % The translated channel names.
        % (This is unique for each channel independent of what the user uses and
        % can be relied on by the backends. See `channelName`.)
        channel_names = {};
        % Map from channel name to channel ID.
        % Include both translated name and untranslated ones as keys.
        cid_cache;
        % Locally disabled channels (translated name prefixes used as keys).
        disabled_channels;
        %% Output related:
        % Whether the default has been overwritten and the new default value.
        % Indexed by the channel ID.
        default_override = false(0);
        default_override_val = [];
        % Output managers indexed by channel ID.
        % These are classes that can process the output (time and value)
        % of a channel after the whole sequence is constructed using the global
        % information of the sequence that's only available at this time.
        % See `TTLMgr` for an example.
        output_manager = {};
        % Cache of the output manager output.
        pulses_overwrite = {};
        %% Running related:
        % Callback to be called (without argument) before the sequence start
        before_start_cbs = {};
        % Callback to be called (without argument) after the sequence finishes
        after_end_cbs = {};
        %% IR stuff
        % For dealing with code generation within the sequence.
        ir_ctx;
    end
    properties(Constant, Access=private)
        % Class-level flag toggled by `ExpSeq.disable`. While it is set,
        % `run`, `run_async` and `waitFinish` return immediately.
        disabled = MutableRef(false);
    end
    methods
        function self = ExpSeq(varargin)
            %% Create a new root sequence.
            % Accepts at most one argument, which must be a struct and is
            % forwarded to the `ExpSeqBase` constructor.
            if nargin > 1
                error('Too many arguments for ExpSeq.');
            elseif nargin == 1 && ~isstruct(varargin{1})
                error('Constant input must be a struct.');
            end
            self = self@ExpSeqBase(varargin{:});
            self.drivers = containers.Map();
            self.driver_cids = containers.Map();
            self.cid_cache = containers.Map('KeyType', 'char', 'ValueType', 'double');
            self.ir_ctx = IRContext();
            self.disabled_channels = containers.Map('KeyType', 'char', ...
                                                    'ValueType', 'double');
        end
        function res = totalTime(self)
            %% Total length of the sequence.
            % Returns the value cached by `generate` when it is available.
            % Note that this total time does not account for the timing
            % difference caused by output managers.
            %
            % `cached_total_time` uses -1 as the "not cached" sentinel, so a
            % cached length of exactly 0 is a valid cached value as well.
            if self.cached_total_time >= 0
                res = self.cached_total_time;
                return;
            end
            res = totalTime@ExpSeqBase(self);
        end
        function addTTLMgr(self, chn, off_delay, on_delay, ...
                           skip_time, min_time, off_val)
            %% Attach a `TTLMgr` output manager to channel `chn`.
            % `chn` may be a channel name; it is translated to a channel ID
            % (creating the channel if needed). The remaining arguments are
            % forwarded to the `TTLMgr` constructor; `off_val` defaults to `false`.
            if ~exist('off_val', 'var')
                off_val = false;
            end
            chn = translateChannel(self, chn);
            mgr = TTLMgr(self, chn, off_delay, on_delay, skip_time, min_time, off_val);
            self.output_manager{chn} = mgr;
        end
        function cid = translateChannel(self, name)
            %% Convert a channel name to a channel ID.
            % A new ID is created if it does not exist yet.
            % For a newly created channel that is not disabled, the driver for
            % its device (first path component) is initialized and the channel
            % is registered with that driver.
            [cid, name, inited] = getChannelId(self, name);
            if inited || checkChannelDisabled(self, name)
                return;
            end
            cpath = strsplit(name, '/');
            did = cpath{1};
            [driver, driver_name] = initDeviceDriver(self, did);
            driver.initChannel(cid);
            cur_cids = self.driver_cids(driver_name);
            self.driver_cids(driver_name) = unique([cur_cids, cid]);
        end
        function driver = findDriver(self, driver_name)
            %% Lazily create driver of the given name.
            % A new driver is constructed by calling the function/class named
            % `driver_name` with this sequence as the only argument, and it
            % starts with an empty channel ID list in `driver_cids`.
            if isKey(self.drivers, driver_name)
                driver = self.drivers(driver_name);
                return;
            end
            driver_func = str2func(driver_name);
            driver = driver_func(self);
            self.drivers(driver_name) = driver;
            self.driver_cids(driver_name) = [];
        end
        function generate(self, preserve)
            %% Called after the sequence is fully constructed.
            % Collect global information (e.g. total time, channel mask)
            % before letting the drivers generate their output (cached in the
            % driver if needed). Calling this again after a successful
            % generation does nothing.
            %
            % Unless `preserve` is true (default: false), the per-channel
            % bookkeeping and the sub-sequences are released afterwards to
            % free memory.
            if self.generated
                return;
            end
            if ~exist('preserve', 'var')
                preserve = 0;
            end
            total_time = totalTime(self);
            self.cached_total_time = total_time;
            if self.config.maxLength > 0 && total_time > self.config.maxLength
                error('Sequence length %f exceeds max sequence length of maxLength=%f', ...
                      total_time, self.config.maxLength);
            end
            fprintf('|');
            populateChnMask(self, length(self.channel_names));
            % Every driver is prepared before any driver generates.
            for key = self.drivers.keys()
                driver_name = key{:};
                driver = self.drivers(driver_name);
                cids = self.driver_cids(driver_name);
                driver.prepare(cids);
            end
            for key = self.drivers.keys()
                driver_name = key{:};
                driver = self.drivers(driver_name);
                cids = self.driver_cids(driver_name);
                driver.generate(cids);
            end
            drivers = {};
            for driver = self.drivers.values()
                % Store the negated priority so that the ascending `sortrows`
                % below puts the driver with the largest priority first.
                drivers = [drivers; {driver{:}, -driver{:}.getPriority()}];
            end
            if ~isempty(drivers)
                drivers = sortrows(drivers, 2);
                self.drivers_sorted = drivers(:, 1);
            end
            self.generated = true;
            if ~preserve
                self.default_override = false(0);
                self.default_override_val = [];
                self.orig_channel_names = [];
                self.channel_names = [];
                self.cid_cache = [];
                self.output_manager = [];
                self.pulses_overwrite = [];
                % NiDAC backend currently need config
                self.subSeqs = [];
            end
        end
        function run_async(self)
            %% Start the run and return.
            % Generate the sequence first if it is not done yet.
            %
            % Do **NOT** put anything related to runSeq in this file!!!!!!!!!!
            % It messes up EVERYTHING!!!!!!!!!!!!!!!!!!!!!!
            if ExpSeq.disabled.get()
                return;
            end
            generate(self);
            run_real(self);
        end
        function run_real(self)
            %% Lower-level version of `run_async`.
            % Assumes the generation is already done and does not check the
            % disable-run flag. Calls the before-start callbacks in
            % registration order, then starts the drivers in priority order.
            drivers = self.drivers_sorted;
            if ~isempty(self.before_start_cbs)
                for cb = self.before_start_cbs
                    cb{:}();
                end
            end
            for i = 1:length(drivers)
                run(drivers{i});
            end
        end
        function self = regBeforeStart(self, cb)
            %% Register a callback function that will be executed before
            % the sequence run.
            % The callbacks will be called in the order they are registered
            % without any arguments.
            self.before_start_cbs{end + 1} = cb;
        end
        function self = regAfterEnd(self, cb)
            %% Register a callback function that will be executed after
            % the sequence ends.
            % The callbacks will be called in the order they are registered
            % without any arguments.
            self.after_end_cbs{end + 1} = cb;
        end
        function waitFinish(self)
            %% Wait for the sequences to finish.
            % Waits on every driver in priority order, then calls the
            % after-end callbacks in registration order.
            %
            % Do **NOT** put anything related to runSeq in this file!!!!!!!!!!
            % It messes up EVERYTHING!!!!!!!!!!!!!!!!!!!!!!
            if ExpSeq.disabled.get()
                return;
            end
            drivers = self.drivers_sorted;
            for i = 1:length(drivers)
                wait(drivers{i, 1});
            end
            if ~isempty(self.after_end_cbs)
                for cb = self.after_end_cbs
                    cb{:}();
                end
            end
        end
        function run(self)
            %% Run the sequence (after generating) and wait for it to finish.
            % Do **NOT** put anything related to runSeq in this file!!!!!!!!!!
            % It messes up EVERYTHING!!!!!!!!!!!!!!!!!!!!!!
            % Also, this function has to be only run_async() and then
            % waitFinish() do not put any more complex logic in.
            % `disabled` is fine since it doesn't mutate anything.
            if ExpSeq.disabled.get()
                return;
            end
            % `now()` is in days; convert to seconds.
            start_t = now() * 86400;
            run_async(self);
            fprintf('Running @%s\n', datestr(now(), 'yyyy/mm/dd HH:MM:SS'));
            % We'll wait until this time (the sequence length after the start,
            % minus 5 ms) before returning to the caller.
            end_after = start_t + totalTime(self) - 5e-3;
            waitFinish(self);
            end_t = now() * 86400;
            if end_t < end_after
                pause(end_after - end_t);
            end
        end
        function self = setDefault(self, name, val)
            %% Override default value in the `expConfig`.
            % `name` may be a channel ID or a channel name.
            if isnumeric(name)
                cid = name;
            else
                cid = translateChannel(self, name);
            end
            self.default_override(cid) = true;
            self.default_override_val(cid) = val;
        end
        function plot(self, varargin)
            %% Plot the values of the selected channels over the whole sequence.
            % Each argument is a channel name string:
            % * a name ending with `/` selects every used channel whose
            %   translated name starts with the translated prefix;
            % * a name starting with `~` is a regular expression matched
            %   against the translated names of all used channels;
            % * any other name selects that single (existing) channel.
            if nargin <= 1
                error('Please specify at least one channel to plot.');
            end
            populateChnMask(self, length(self.channel_names));
            cids = [];
            names = {};
            for i = 1:(nargin - 1)
                arg = varargin{i};
                if ~ischar(arg)
                    error('Channel name has to be a string');
                end
                if isempty(arg)
                    % `arg(end)` below would otherwise fail with an index error.
                    error('Channel name cannot be empty.');
                end
                if arg(end) == '/'
                    matches = regexp(arg, '^(.*[^/])/*$', 'tokens');
                    if isempty(matches)
                        % e.g. `'/'` or `'//'`: no prefix left after the slashes.
                        error('Invalid channel prefix %s.', arg);
                    end
                    prefix = translateChannel(self.config, matches{1}{1});
                    prefix_len = size(prefix, 2);
                    for cid = 1:length(self.orig_channel_names)
                        orig_name = self.orig_channel_names{cid};
                        if isempty(orig_name)
                            continue;
                        end
                        name = translateChannel(self.config, orig_name);
                        if strncmp(prefix, name, prefix_len)
                            cids(end + 1) = cid;
                            names{end + 1} = orig_name;
                        end
                    end
                elseif arg(1) == '~'
                    arg = arg(2:end);
                    for cid = 1:length(self.orig_channel_names)
                        orig_name = self.orig_channel_names{cid};
                        if isempty(orig_name)
                            continue;
                        end
                        name = translateChannel(self.config, orig_name);
                        if ~isempty(regexp(name, arg))
                            cids(end + 1) = cid;
                            names{end + 1} = orig_name;
                        end
                    end
                else
                    name = translateChannel(self.config, arg);
                    if isKey(self.cid_cache, name)
                        % Look up the name without creating an ID
                        cid = self.cid_cache(name);
                    else
                        error('Channel %s does not exist.', arg);
                    end
                    cids(end + 1) = cid;
                    names{end + 1} = arg;
                end
            end
            if size(cids, 2) == 0
                error('No channel to plot.');
            end
            plotReal(self, cids, names);
        end
        function name = channelName(self, cid)
            %% Translated name of channel `cid`.
            name = self.channel_names{cid};
        end
        function vals = getValues(self, dt, varargin)
            %% Sample the values of the given channels on a uniform time grid.
            % Each extra argument is either a channel ID or a cell `{cid, scale}`
            % whose sampled values are multiplied by `scale`.
            % Row `i` of `vals` belongs to the `i`-th channel argument and
            % column `k` holds the value at time `(k - 1) * dt`.
            % NOTE(review): `fld`/`cld` are assumed to be floor/ceiling
            % division helpers defined elsewhere — confirm.
            total_t = totalTime(self);
            nstep = fld(total_t, dt) + 1;
            nchn = nargin - 2;
            vals = zeros(nchn, nstep);
            for i = 1:nchn
                chn = varargin{i};
                if isnumeric(chn)
                    scale = 1;
                else
                    scale = chn{2};
                    chn = chn{1};
                end
                pulses = getPulseTimes(self, chn);
                pidx = 1;
                vidx = 1;
                npulses = size(pulses, 1);
                cur_value = getDefault(self, chn);
                while pidx <= npulses
                    % At the beginning of each loop, pidx points to the pulse to be
                    % processed, vidx points to the value to be filled, cur_value is the
                    % value of the channel right before vidx.
                    % First fill the values before the next pulse starts.
                    pulse = pulses(pidx, :);
                    % Index before next time
                    next_vidx = cld(pulse{1}, dt);
                    if next_vidx >= vidx
                        vals(i, vidx:next_vidx) = cur_value * scale;
                    end
                    next_time = next_vidx * dt;
                    vidx = next_vidx + 1;
                    % Now find the last pulse that starts no later than the next point.
                    cur_pulse = {};
                    while true
                        % At the beginning of each loop, pidx and pulse point to the
                        % pulse to be processed, vidx should never change and should
                        % point to the value to be filled, cur_value is the value at
                        % the end of the last pulse.
                        switch pulse{2}
                            case TimeType.Dirty
                                pulse_obj = pulse{3};
                                if isnumeric(pulse_obj)
                                    cur_value = pulse_obj;
                                else
                                    cur_value = calcValue(pulse_obj, pulse{7}, ...
                                                          pulse{5}, cur_value);
                                end
                                pidx = pidx + 1;
                                if pidx > npulses
                                    % End of pulses
                                    pidx = 0;
                                    break;
                                end
                                pulse = pulses(pidx, :);
                                if pulse{1} > next_time
                                    break;
                                end
                            case TimeType.Start
                                % The row right after a start row is taken
                                % to be its matching end row.
                                pidx = pidx + 1;
                                if pidx > npulses
                                    error('Unmatch pulse start and end.');
                                end
                                pulse_end = pulses(pidx, :);
                                if pulse_end{1} > next_time
                                    cur_pulse = pulse;
                                    break;
                                end
                                pulse_obj = pulse{3};
                                % Forward to the end of the pulse since it is shorter than
                                % our time interval.
                                cur_value = calcValue(pulse_obj, pulse_end{1} - pulse{4}, ...
                                                      pulse{5}, cur_value);
                                pidx = pidx + 1;
                                if pidx > npulses
                                    % End of pulses
                                    pidx = 0;
                                    break;
                                end
                                pulse = pulses(pidx, :);
                                if pulse{1} > next_time
                                    break;
                                end
                            otherwise
                                error('Invalid pulse type.');
                        end
                    end
                    % There are three possibilities when we exit the loop
                    % 1. we are at the end of the pulses:
                    %    Just fill the rest of the sequence with the current value
                    %    and done for the channel.
                    if ~pidx
                        break;
                    end
                    % 2. all the processed pulses finish before the next time point:
                    %    Finish the current process and run the next loop.
                    if isempty(cur_pulse)
                        continue;
                    end
                    % 3. we've started a pulse and it continues past the next time point:
                    %    Calculate values for this pulse and run the next loop.
                    last_vidx = cld(pulse_end{1}, dt);
                    idxs = vidx:last_vidx;
                    pulse_obj = pulse{3};
                    vals(i, idxs) = calcValue(pulse_obj, (idxs - 1) * dt - pulse{4}, ...
                                              pulse{5}, cur_value) * scale;
                    cur_value = calcValue(pulse_obj, pulse_end{1} - pulse{4}, ...
                                          pulse{5}, cur_value);
                    % Skip past the end row of the pulse just sampled.
                    pidx = pidx + 1;
                    vidx = last_vidx + 1;
                end
                vals(i, vidx:end) = cur_value * scale;
            end
        end
        function res = getPulseTimes(self, cid)
            %% Expand the pulses of channel `cid` into a time-sorted event table.
            % Each row is `{time, type, pulse_obj, toffset, step_len, cid, offset}`:
            % * a numeric pulse produces one `TimeType.Dirty` row at `toffset`;
            % * any other pulse produces a `TimeType.Start` row at `toffset`
            %   and a `TimeType.End` row at `toffset + step_len`.
            % `offset` is the row time relative to `toffset` (`0` for
            % `Dirty`/`Start` rows, `step_len` for `End` rows). `type` is stored
            % as `int32`. Rows are sorted by time, then type, then offset.
            res = {};
            pulses = getPulses(self, cid);
            for i = 1:size(pulses, 1)
                pulse = pulses(i, :);
                toffset = pulse{1};
                step_len = pulse{2};
                pulse_obj = pulse{3};
                if isnumeric(pulse_obj)
                    res(end + 1, 1:7) = {toffset, int32(TimeType.Dirty), pulse_obj, ...
                                         toffset, step_len, cid, 0};
                else
                    res(end + 1, 1:7) = {toffset, int32(TimeType.Start), pulse_obj, ...
                                         toffset, step_len, cid, 0};
                    res(end + 1, 1:7) = {toffset + step_len, int32(TimeType.End), pulse_obj, ...
                                         toffset, step_len, cid, step_len};
                end
            end
            if ~isempty(res)
                res = sortrows(res, [1, 2, 7]);
            end
        end
        function res = getPulses(self, cid)
            %% Return a cell array with one pulse per row: `toffset, length, pulse_obj`.
            % `appendPulses` collects the pulses column-wise (3 rows); the
            % result is transposed here so that callers can index rows.
            % The `pulse_obj` should be a number or a `PulseBase` (see `PulseBase::calcValue`).
            % See `ExpSeqBase::appendPulses`.
            % The returned rows are sorted by toffset. If channel `cid` has an
            % output manager, its processed result is returned and cached.
            %
            % This must be run after `populateMask`
            if length(self.pulses_overwrite) >= cid && ~isempty(self.pulses_overwrite{cid})
                res = self.pulses_overwrite{cid};
                return;
            end
            % Transpose unconditionally so that an empty 3-row result also
            % comes back in the one-pulse-per-row layout.
            res = appendPulses(self, cid, {}, 0)';
            if ~isempty(res)
                res = sortrows(res, 1);
            end
            if length(self.output_manager) >= cid && ~isempty(self.output_manager{cid})
                res = processPulses(self.output_manager{cid}, res);
                self.pulses_overwrite{cid} = res;
            end
        end
        function val = getDefault(self, cid)
            %% Default value of channel `cid`.
            % Uses the value set by `setDefault` if there is one, otherwise the
            % entry in `config.defaultVals` for the translated channel name,
            % falling back to `0`.
            if length(self.default_override) >= cid && self.default_override(cid)
                val = self.default_override_val(cid);
                return;
            end
            name = channelName(self, cid);
            if isKey(self.config.defaultVals, name)
                val = self.config.defaultVals(name);
            else
                val = 0;
            end
        end
        function disableChannel(self, name)
            %% Disable channels with the prefix `name` (so `$name/...` or `name` itself)
            % Disabled channels are still added to the sequence but are hidden from the backend.
            % Raises an error if a matching channel has already been used.
            name = translateChannel(self.config, name);
            % This check is in principle O(M * N) in total
            % where M is the number of disabled channels and N is the number of used channels.
            % However, in practice the disable channel should only be called
            % at the beginning of the sequence so this shouldn't be too bad.
            % `getChannelId` guarantees that all translated names used are in `cid_cache`
            for key = keys(self.cid_cache)
                % This should not have false positive disabled channels
                % for the same reason as `checkChannelDisabled`.
                % name is always translated here.
                key = key{:};
                if strcmp(key, name) || startsWith(key, [name, '/'])
                    error('Cannot disable channel that is already initialized');
                end
            end
            if isempty(self.disabled_channels) && ~self.G.localDisableWarned(false)
                self.G.localDisableWarned = true;
                warning('Channel disabled locally.');
            end
            self.disabled_channels(name) = 0;
        end
        function res = checkChannelDisabled(self, name)
            %% Whether channel `name` is disabled locally or by the config.
            % `name` is assumed to be translated. Returns false for untranslated name.
            % See `SeqConfig::checkChannelDisabled`
            for key = keys(self.disabled_channels)
                key = key{:};
                if strcmp(name, key) || startsWith(name, [key, '/'])
                    res = true;
                    return;
                end
            end
            res = checkChannelDisabled(self.config, name);
        end
    end
    methods(Access=protected)
        function t = globalPath(self)
            %% The root sequence has an empty global path.
            t = {};
        end
    end
    methods(Access=private)
        function [cid, name, inited] = getChannelId(self, name)
            %% Look up or create the ID of channel `name`.
            % `inited` is false only when a new channel ID was created by this
            % call. `name` is returned translated, except when the raw name was
            % already a key of `cid_cache`.
            inited = true;
            if isKey(self.cid_cache, name)
                cid = self.cid_cache(name);
                return;
            end
            orig_name = name;
            name = translateChannel(self.config, name);
            if isKey(self.cid_cache, name)
                cid = self.cid_cache(name);
                return;
            end
            cid = length(self.channel_names) + 1;
            self.channel_names{cid} = name;
            % This makes sure that disableChannel
            % could iterate over all the translated names.
            self.cid_cache(name) = cid;
            if ~strcmp(name, orig_name)
                self.cid_cache(orig_name) = cid;
            end
            if (cid > length(self.orig_channel_names) || ...
                isempty(self.orig_channel_names{cid}))
                self.orig_channel_names{cid} = orig_name;
                inited = false;
            end
        end
        function [driver, driver_name] = initDeviceDriver(self, did)
            %% Find (or create) the driver for device `did` and initialize the device on it.
            driver_name = self.config.pulseDrivers(did);
            driver = findDriver(self, driver_name);
            initDev(driver, did);
        end
        function plotReal(self, cids, names)
            %% Plot channels `cids` (labelled with `names`) over the whole sequence.
            % The sequence is sampled with a step of 1/1e6 of its total length.
            cids = num2cell(cids);
            len = totalTime(self);
            dt = len / 1e6;
            data = getValues(self, dt, cids{:})';
            % `getValues` puts the value at time `(k - 1) * dt` in sample `k`,
            % so the time axis starts at 0.
            ts = (0:(size(data, 1) - 1)) * dt;
            plot(ts, data);
            xlabel('t / s');
            legend(names{:});
        end
    end
    methods(Static)
        function disabler = disable(val)
            %% Set the class-level disable flag to `val`.
            % Returns an object that resets the flag to `false` through `cb`
            % (presumably when it is destroyed, like `onCleanup` — confirm
            % against `FacyOnCleanup`).
            ExpSeq.disabled.set(val);
            % Using an anonymous function here upsets MATLAB's parser...
            function cb()
                ExpSeq.disabled.set(false);
            end
            disabler = FacyOnCleanup(@cb);
        end
    end
end