Skip to content

Commit

Permalink
Merge pull request #33 from BFedder/return_dict
Browse files Browse the repository at this point in the history
Refactoring the way data is returned in panedr.
  • Loading branch information
hmacdope authored Jun 29, 2022
2 parents 81289f1 + 2659211 commit 84bd117
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 8 deletions.
31 changes: 27 additions & 4 deletions panedr/panedr.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@
import sys
import itertools
import time
import pandas
import numpy as np


#Index for the IDs of additional blocks in the energy file.
#Blocks can be added without sacrificing backward and forward
Expand Down Expand Up @@ -75,7 +76,7 @@
Enxnm = collections.namedtuple('Enxnm', 'name unit')
ENX_VERSION = 5

__all__ = ['edr_to_df']
__all__ = ['edr_to_df', 'edr_to_dict', 'read_edr']


class EDRFile(object):
Expand Down Expand Up @@ -395,14 +396,14 @@ def edr_strings(data, file_version, n):

def is_frame_magic(data):
"""Unpacks an int and checks whether it matches the EDR frame magic number
Does not roll the reading position back.
"""
magic = data.unpack_int()
return magic == -7777777


def edr_to_df(path, verbose=False):
def read_edr(path, verbose=False):
begin = time.time()
edr_file = EDRFile(str(path))
all_energies = []
Expand All @@ -427,5 +428,27 @@ def edr_to_df(path, verbose=False):
end='', file=sys.stderr)
print('\n{} frame read in {:.2f} seconds'.format(ifr, end - begin),
file=sys.stderr)

return all_energies, all_names, times


def edr_to_df(path: str, verbose: bool = False):
try:
import pandas
except ImportError:
raise ImportError("""ERROR --- pandas was not found!
pandas is required to use the `.edr_to_df()`
functionality. Try installing it using pip, e.g.:
python -m pip install pandas""")
all_energies, all_names, times = read_edr(path, verbose=verbose)
df = pandas.DataFrame(all_energies, columns=all_names, index=times)
return df


def edr_to_dict(path: str, verbose: bool = False):
all_energies, all_names, times = read_edr(path, verbose=verbose)
energy_dict = {}
for idx, name in enumerate(all_names):
energy_dict[name] = np.array(
[all_energies[frame][idx] for frame in range(len(times))])
return energy_dict
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
pandas
numpy>=1.19.0
pbr
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,5 @@ classifier =
test =
six
pytest
pandas =
pandas
23 changes: 20 additions & 3 deletions tests/test_edr.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@
EDR_Data = namedtuple('EDR_Data', ['df', 'xvgdata', 'xvgtime', 'xvgnames',
'xvgprec', 'edrfile', 'xvgfile'])


def test_failed_import(monkeypatch):
# Putting this test first to avoid datafiles already being loaded
errmsg = "ERROR --- pandas was not found!"
monkeypatch.setitem(sys.modules, 'pandas', None)
with pytest.raises(ImportError, match=errmsg):
panedr.edr_to_df(EDR)


@pytest.fixture(scope='module',
params=[(EDR, EDR_XVG),
(EDR_IRREGULAR, EDR_IRREGULAR_XVG),
Expand All @@ -73,7 +82,7 @@ def edr(request):
xvgtime = xvgdata[:, 0]
xvgdata = xvgdata[:, 1:]
return EDR_Data(df, xvgdata, xvgtime, xvgnames, xvgprec, edrfile, xvgfile)


class TestEdrToDf(object):
"""
Expand Down Expand Up @@ -163,10 +172,18 @@ def _assert_progress_range(self, progress, dt, start, stop, step):
assert ref_line == progress_line


def test_edr_to_dict_matches_edr_to_df():
array_dict = panedr.edr_to_dict(EDR)
ref_df = panedr.edr_to_df(EDR)
array_df = pandas.DataFrame.from_dict(array_dict).set_index(
"Time", drop=False)
assert array_df.equals(ref_df)


def read_xvg(path):
"""
Reads XVG file, returning the data, names, and precision.
The data is returned as a 2D numpy array. Column names are returned as an
array of string objects. Precision is an integer corresponding to the least
number of decimal places found, excluding the first (time) column.
Expand Down Expand Up @@ -205,7 +222,7 @@ def read_xvg(path):

def ndec(val):
"""Returns the number of decimal places of a string rep of a float
"""
try:
return len(re.split(NDEC_PATTERN, val)[1])
Expand Down

0 comments on commit 84bd117

Please sign in to comment.