Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pyxdf speedup #39

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 135 additions & 15 deletions Python/pyxdf/pyxdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,33 @@ def load_xdf(filename,

"""

class XDFFormatInfo:
"""This class stores how many bytes are occupied by each part of the xdf file format.
The numbers are based on the official documentation at https://github.com/sccn/xdf/wiki/Specifications"""
class GenericChunk:
TAG_BYTES = 2

class SampleChunk:
STREAM_ID_BYTES = 4
LEN_NUM_SAMPLE_BYTES = 1
# NUM_SAMPLE_BYTES is variable depending on LEN_NUM_SAMPLE_BYTES

@staticmethod
def get_header_length(num_sample_bytes: int):
return XDFFormatInfo.GenericChunk.TAG_BYTES \
+ XDFFormatInfo.SampleChunk.STREAM_ID_BYTES \
+ XDFFormatInfo.SampleChunk.LEN_NUM_SAMPLE_BYTES \
+ num_sample_bytes

@staticmethod
def get_timestamp_count(num_payload_bytes: int, nsamples: int, samplebytes: int):
num_timestamp_bytes = num_payload_bytes - nsamples * (samplebytes + XDFFormatInfo.Sample.TIMESTAMP_EXISTS_BYTES)
return num_timestamp_bytes / XDFFormatInfo.Sample.TIMESTAMP_BYTES

class Sample:
TIMESTAMP_EXISTS_BYTES = 1
TIMESTAMP_BYTES = 8

class StreamData:
"""Temporary per-stream data."""
def __init__(self, xml):
Expand All @@ -180,6 +207,19 @@ def __init__(self, xml):
self.srate = round(float(xml['info']['nominal_srate'][0]))
# format string (int8, int16, int32, float32, double64, string)
self.fmt = xml['info']['channel_format'][0]
self.numpy_fmt = None
if self.fmt == 'int8':

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not a dict?

self.numpy_fmt = np.int8
elif self.fmt == 'int16':
self.numpy_fmt = np.int16
elif self.fmt == 'int32':
self.numpy_fmt = np.int32
elif self.fmt == 'int64':
self.numpy_fmt = np.int64
elif self.fmt == 'float32':
self.numpy_fmt = np.float32
elif self.fmt == 'double64':
self.numpy_fmt = np.float64
# list of time-stamp chunks (each an ndarray, in seconds)
self.time_stamps = []
# list of time-series chunks (each an ndarray or list of lists)
Expand All @@ -198,6 +238,10 @@ def __init__(self, xml):
self.samplebytes = self.nchns * fmt2nbytes[self.fmt]
# format string to pass to struct.unpack() to handle one sample
self.structfmt = '<%s%s' % (self.nchns, fmt2char[self.fmt])
# used to parse (parts of) chunks that are guaranteed to have no / a timestamp associated with
# every sample (the x stands for the TIMESTAMP_EXISTS byte which has to be ignored)
self.structfmt_no_timestamp = '<x%s%s' % (self.nchns, fmt2char[self.fmt])
self.structfmt_with_timestamp = '<xd%s%s' % (self.nchns, fmt2char[self.fmt])

logger.info('Importing XDF file %s...' % filename)
if not os.path.exists(filename):
Expand Down Expand Up @@ -266,7 +310,8 @@ def __init__(self, xml):
# noinspection PyBroadException
try:
# read [NumSampleBytes], [NumSamples]
nsamples = _read_varlen_int(f)
num_sample_bytes = _read_varlen_bytecount(f)
nsamples = _read_len_int(f, num_sample_bytes)
# allocate space
stamps = np.zeros((nsamples,))
if temp[StreamId].fmt == 'string':
Expand All @@ -288,19 +333,85 @@ def __init__(self, xml):
values[k][ch] = raw.decode(errors='replace')
else:
# read a sample comprised of numeric values
values = np.zeros((nsamples, temp[StreamId].nchns))
# for each sample...
for k in range(nsamples):
# read or deduce time stamp
if struct.unpack('B', f.read(1))[0]:
stamps[k] = struct.unpack('<d', f.read(8))[0]
values = np.zeros((nsamples, temp[StreamId].nchns), dtype=temp[StreamId].numpy_fmt)

num_payload_bytes = chunklen - XDFFormatInfo.SampleChunk.get_header_length(num_sample_bytes)
num_timestamps = XDFFormatInfo.SampleChunk.get_timestamp_count(
num_payload_bytes, nsamples, temp[StreamId].samplebytes)

remaining_num_timestamps = num_timestamps
remaining_num_samples = nsamples

# if only some samples are associated with a timestamp
if remaining_num_timestamps > 0 and remaining_num_timestamps != remaining_num_samples:
for k in range(nsamples):
# read or deduce time stamp
if struct.unpack('B', f.read(XDFFormatInfo.Sample.TIMESTAMP_EXISTS_BYTES))[0]:
stamps[k] = struct.unpack('<d', f.read(XDFFormatInfo.Sample.TIMESTAMP_BYTES))[0]
remaining_num_timestamps -= 1
else:
stamps[k] = (temp[StreamId].last_timestamp +
temp[StreamId].tdiff)
temp[StreamId].last_timestamp = stamps[k]
# read the values
raw = f.read(temp[StreamId].samplebytes)
values[k, :] = struct.unpack(temp[StreamId].structfmt, raw)

remaining_num_samples -= 1
if remaining_num_timestamps <= 0 or remaining_num_timestamps == remaining_num_samples:
break # if there are no timestamps left or all remaining samples have a timestamp

if remaining_num_samples > 0:
# now it's guaranteed that either no or every remaining sample is associated with
# a timestamp -> parse it all at once
all_have_timestamps = remaining_num_timestamps > 0

if all_have_timestamps:
samplesize = XDFFormatInfo.Sample.TIMESTAMP_EXISTS_BYTES \
+ XDFFormatInfo.Sample.TIMESTAMP_BYTES \
+ temp[StreamId].samplebytes
structfmt = temp[StreamId].structfmt_with_timestamp
num_dimensions = temp[StreamId].nchns + 1 # the +1 adds a column for timestamps
np_dtype = np.float64 # float64 is used because this format is used for timestamps
else: # no remaining sample is associated with a timestamp
samplesize = XDFFormatInfo.Sample.TIMESTAMP_EXISTS_BYTES + temp[StreamId].samplebytes
structfmt = temp[StreamId].structfmt_no_timestamp
num_dimensions = temp[StreamId].nchns
np_dtype = temp[StreamId].numpy_fmt

chunksize = remaining_num_samples * samplesize
index = nsamples - remaining_num_samples

raw_chunk = f.read(chunksize)
chunk_value_iterator = struct.iter_unpack(structfmt, raw_chunk)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

np.fromfile should be a lot faster

# flattens the iterator; np.fromiter can't handle nested iterators
chunk_value_iterator = iter(itertools.chain.from_iterable(chunk_value_iterator))
chunk_values = np.fromiter(chunk_value_iterator,
dtype=np_dtype,
count=remaining_num_samples * num_dimensions)
# converts the flat list back to a nested format
chunk_values = chunk_values.reshape((remaining_num_samples, num_dimensions))

if all_have_timestamps:
values[index:, :] = chunk_values[:, 1:]
stamps[index:] = chunk_values[:, 0]

else:
stamps[k] = (temp[StreamId].last_timestamp +
temp[StreamId].tdiff)
temp[StreamId].last_timestamp = stamps[k]
# read the values
raw = f.read(temp[StreamId].samplebytes)
values[k, :] = struct.unpack(temp[StreamId].structfmt, raw)
values[index:, :] = chunk_values

# as those samples don't have associated timestamps whe have to deduce them
if temp[StreamId].tdiff == 0:
stamps[index:] = temp[StreamId].last_timestamp
else:
new_last_timestamp = temp[StreamId].last_timestamp \
+ temp[StreamId].tdiff * remaining_num_samples
stamps[index:] = np.arange(
start=temp[StreamId].last_timestamp + temp[StreamId].tdiff,
stop=new_last_timestamp + 0.5 * temp[StreamId].tdiff,
step=temp[StreamId].tdiff) # is there a more elegant way to do this?

temp[StreamId].last_timestamp = stamps[-1]

logger.debug(' reading [%s,%s]' % (temp[StreamId].nchns,
nsamples))
# optionally send through the on_chunk function
Expand Down Expand Up @@ -343,7 +454,7 @@ def __init__(self, xml):
if stream.fmt == 'string':
stream.time_series = []
else:
stream.time_series = np.zeros((stream.nchns, 0))
stream.time_series = np.zeros((stream.nchns, 0), dtype=stream.numpy_fmt)

# perform (fault-tolerant) clock synchronization if requested
if synchronize_clocks:
Expand Down Expand Up @@ -380,7 +491,12 @@ def __init__(self, xml):

def _read_varlen_int(f):
"""Read a variable-length integer."""
nbytes = struct.unpack('B', f.read(1))[0]
nbytes = _read_varlen_bytecount(f)
return _read_len_int(f, nbytes)


def _read_len_int(f, nbytes):
"""Read a integer whose length is known."""
if nbytes == 1:
return struct.unpack('B', f.read(1))[0]
elif nbytes == 4:
Expand All @@ -391,6 +507,10 @@ def _read_varlen_int(f):
raise RuntimeError('invalid variable-length integer encountered.')


def _read_varlen_bytecount(f):
"""Read the length of the following integer."""
return struct.unpack('B', f.read(1))[0]

def _xml2dict(t):
"""Convert an attribute-less etree.Element into a dict."""
dd = defaultdict(list)
Expand Down