Skip to content

Commit

Permalink
Dataexample update (#1774)
Browse files Browse the repository at this point in the history
Update dataexample remote data to use zenodo_get for download

Signed-off-by: Hannah Robarts <[email protected]>
Signed-off-by: Casper da Costa-Luis <[email protected]>
Co-authored-by: Casper da Costa-Luis <[email protected]>
  • Loading branch information
hrobarts and casperdcl authored Oct 1, 2024
1 parent b10d15d commit b91c0eb
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 123 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
- BlockOperator that would return a BlockDataContainer of shape (1,1) now returns the appropriate DataContainer. BlockDataContainer direct and adjoint methods accept DataContainer as parameter (#1802).
- BlurringOperator: remove check for geometry class (old SIRF integration bug) (#1807)
- The `ZeroFunction` and `ConstantFunction` now have a Lipschitz constant of 1. (#1768)
- Update `dataexample` remote data download to work on Windows and to use `zenodo_get` to fetch the data (#1774)
- Changes that break backwards compatibility:
- Merged the files `BlockGeometry.py` and `BlockDataContainer.py` in `framework` to one file `block.py`. Please use `from cil.framework import BlockGeometry, BlockDataContainer` as before (#1799)
- Bug fix in `FGP_TV` function to set the default behaviour not to enforce non-negativity (#1826).
Expand Down
102 changes: 74 additions & 28 deletions Wrappers/Python/cil/utilities/dataexample.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@
import os.path
import sys
from zipfile import ZipFile
from urllib.request import urlopen
from io import BytesIO
from scipy.io import loadmat
from cil.io import NEXUSDataReader, NikonDataReader, ZEISSDataReader
from zenodo_get import zenodo_get

class DATA(object):
@classmethod
Expand All @@ -46,21 +45,15 @@ def get(cls, size=None, scale=(0,1), **kwargs):
class REMOTEDATA(DATA):

FOLDER = ''
URL = ''
FILE_SIZE = ''
ZENODO_RECORD = ''
ZIP_FILE = ''

@classmethod
def get(cls, data_dir):
    '''
    Placeholder loader of the remote-data base class.

    Parameters
    ----------
    data_dir: str
        The path to the directory where the dataset is stored (unused here)

    Returns
    -------
    None
        This base implementation loads nothing and always returns None
    '''
    return None

@classmethod
def _download_and_extract_from_url(cls, data_dir):
with urlopen(cls.URL) as response:
with BytesIO(response.read()) as bytes, ZipFile(bytes) as zipfile:
zipfile.extractall(path = data_dir)

@classmethod
def download_data(cls, data_dir, prompt=True):
    '''
    Download a dataset from a remote repository

    The archive ``cls.ZIP_FILE`` is fetched from the Zenodo record
    ``cls.ZENODO_RECORD`` into ``data_dir``, extracted into
    ``data_dir/cls.FOLDER`` and the downloaded zip file is then deleted.

    Parameters
    ----------
    data_dir: str
        The path to the directory the archive is downloaded to and extracted in
    prompt: bool, default True
        If True, ask for confirmation on the command line before downloading.
        If False, download without asking.

    Returns
    -------
    bool or None
        True if the dataset was downloaded and extracted, False if the user
        declined at the prompt, None if ``data_dir/cls.FOLDER`` already exists
        (nothing is downloaded in that case)
    '''
    if os.path.isdir(os.path.join(data_dir, cls.FOLDER)):
        print("Dataset folder already exists in " + data_dir)
        return None

    if prompt:
        # This must be an f-string, otherwise the user is shown the literal
        # text "{cls.ZIP_FILE}". The hint is [y/N] because anything other
        # than 'y'/'yes' (including an empty answer) cancels the download.
        user_input = input(f"Are you sure you want to download {cls.ZIP_FILE} dataset from Zenodo record {cls.ZENODO_RECORD}? [y/N]: ")
    else:
        user_input = 'y'

    if user_input.lower() not in ('y', 'yes'):
        print('Download cancelled')
        return False

    zip_path = os.path.join(data_dir, cls.ZIP_FILE)
    # zenodo_get takes command-line style arguments: '-g' restricts the
    # download to the named file and '-o' sets the output directory.
    zenodo_get([cls.ZENODO_RECORD, '-g', cls.ZIP_FILE, '-o', data_dir])
    with ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(os.path.join(data_dir, cls.FOLDER))
    # The archive is no longer needed once its contents are extracted
    os.remove(zip_path)
    return True

class BOAT(CILDATA):
@classmethod
Expand Down Expand Up @@ -195,15 +192,21 @@ def get(cls, **kwargs):
class WALNUT(REMOTEDATA):
'''
A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516
Example
--------
>>> data_dir = 'my_PC/data_folder'
>>> dataexample.WALNUT.download_data(data_dir) # download the data
>>> dataexample.WALNUT.get(data_dir) # load the data
'''
FOLDER = 'walnut'
URL = 'https://zenodo.org/record/4822516/files/walnut.zip'
FILE_SIZE = '6.4 GB'
ZENODO_RECORD = '4822516'
ZIP_FILE = 'walnut.zip'

@classmethod
def get(cls, data_dir):
'''
A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516
Get the microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516
This function returns the raw projection data from the .txrm file
Parameters
Expand All @@ -227,15 +230,21 @@ def get(cls, data_dir):
class USB(REMOTEDATA):
'''
A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516
Example
--------
>>> data_dir = 'my_PC/data_folder'
>>> dataexample.USB.download_data(data_dir) # download the data
>>> dataexample.USB.get(data_dir) # load the data
'''
FOLDER = 'USB'
URL = 'https://zenodo.org/record/4822516/files/usb.zip'
FILE_SIZE = '3.2 GB'
ZENODO_RECORD = '4822516'
ZIP_FILE = 'usb.zip'

@classmethod
def get(cls, data_dir):
'''
A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516
Get the microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516
This function returns the raw projection data from the .txrm file
Parameters
Expand All @@ -259,15 +268,21 @@ def get(cls, data_dir):
class KORN(REMOTEDATA):
'''
A microcomputed tomography dataset of sunflower seeds in a box from https://zenodo.org/records/6874123
Example
--------
>>> data_dir = 'my_PC/data_folder'
>>> dataexample.KORN.download_data(data_dir) # download the data
>>> dataexample.KORN.get(data_dir) # load the data
'''
FOLDER = 'korn'
URL = 'https://zenodo.org/record/6874123/files/korn.zip'
FILE_SIZE = '2.9 GB'
ZENODO_RECORD = '6874123'
ZIP_FILE = 'korn.zip'

@classmethod
def get(cls, data_dir):
'''
A microcomputed tomography dataset of a sunflower seeds in a box from https://zenodo.org/records/6874123
Get the microcomputed tomography dataset of sunflower seeds in a box from https://zenodo.org/records/6874123
This function returns the raw projection data from the .xtekct file
Parameters
Expand All @@ -279,6 +294,7 @@ def get(cls, data_dir):
-------
ImageData
The korn dataset
'''
filepath = os.path.join(data_dir, cls.FOLDER, 'Korn i kasse','47209 testscan korn01_recon.xtekct')
try:
Expand All @@ -293,10 +309,40 @@ class SANDSTONE(REMOTEDATA):
'''
A synchrotron x-ray tomography dataset of sandstone from https://zenodo.org/records/4912435
A small subset of the data containing selected projections and 4 slices of the reconstruction
Example
--------
>>> data_dir = 'my_PC/data_folder'
>>> dataexample.SANDSTONE.download_data(data_dir) # download the data
>>> dataexample.SANDSTONE.get(data_dir) # load the data
'''
FOLDER = 'sandstone'
URL = 'https://zenodo.org/records/4912435/files/small.zip'
FILE_SIZE = '227 MB'
ZENODO_RECORD = '4912435'
ZIP_FILE = 'small.zip'

@classmethod
def get(cls, data_dir, filename):
    '''
    Get the synchrotron x-ray tomography dataset of sandstone from https://zenodo.org/records/4912435
    A small subset of the data containing selected projections and 4 slices of the reconstruction

    Parameters
    ----------
    data_dir: str
        The path to the directory where the dataset is stored. Data can be downloaded with dataexample.SANDSTONE.download_data(data_dir)
    filename: str
        The slices or projections to return, given as the path to the file relative to data_dir.
        Only MATLAB '.mat' files are supported; the extension is matched case-insensitively.

    Returns
    -------
    dict
        The contents of the selected file as returned by scipy.io.loadmat

    Raises
    ------
    KeyError
        If filename does not have a '.mat' extension
    '''
    extension = os.path.splitext(filename)[1]
    # Compare case-insensitively so that e.g. 'slice.MAT' is also accepted
    if extension.lower() == '.mat':
        return loadmat(os.path.join(data_dir, filename))
    raise KeyError(f"Unknown extension: {extension}")


class TestData(object):
'''Class to return test data
Expand Down
163 changes: 68 additions & 95 deletions Wrappers/Python/test/test_dataexample.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
from testclass import CCPiTestClass
import platform
import numpy as np
from unittest.mock import patch, MagicMock
from urllib import request
from unittest.mock import patch
from zipfile import ZipFile
from io import StringIO
from tempfile import NamedTemporaryFile
import uuid
from zenodo_get import zenodo_get

initialise_tests()

Expand Down Expand Up @@ -157,116 +157,89 @@ def test_load_SIMULATED_CONE_BEAM_DATA(self):
class TestRemoteData(unittest.TestCase):

def setUp(self):

self.data_list = ['WALNUT','USB','KORN','SANDSTONE']
self.shapes_path = os.path.join(dataexample.CILDATA.data_dir, dataexample.TestData.SHAPES)

def mock_urlopen(self, mock_urlopen, zipped_bytes):
mock_response = MagicMock()
mock_response.read.return_value = zipped_bytes
mock_response.__enter__.return_value = mock_response
mock_urlopen.return_value = mock_response

@unittest.skipIf(platform.system() == 'Windows', "Skip on Windows")
@patch('cil.utilities.dataexample.urlopen')
def test_unzip_remote_data(self, mock_urlopen):
'''
Test the _download_and_extract_data_from_url function correctly extracts files from a byte string
The zipped byte string is mocked using a temporary local zip file
def mock_zenodo_get(*args):
    # Stand-in for zenodo_get: instead of downloading anything, write a zip
    # file containing the shapes test image to where the download would land.
    # The first positional argument is the argument list given to zenodo_get;
    # element 2 is used as the zip file name and element 4 as the output directory.
    cli_args = args[0]
    zip_name, out_dir = cli_args[2], cli_args[4]
    source_file = os.path.join(dataexample.CILDATA.data_dir, dataexample.TestData.SHAPES)
    with ZipFile(os.path.join(out_dir, zip_name), mode='w') as archive:
        archive.write(source_file, arcname=dataexample.TestData.SHAPES)


@patch('cil.utilities.dataexample.input', return_value='y')
@patch('cil.utilities.dataexample.zenodo_get', side_effect=mock_zenodo_get)
def test_download_data_input_y(self, mock_zenodo_get, input):
'''

# create a temporary zip file to test the function
with NamedTemporaryFile(suffix = '.zip') as tf:
tmp_path = os.path.dirname(tf.name)
tmp_dir = os.path.splitext(os.path.basename(tf.name))[0]
with ZipFile(tf.name, mode='w') as zip_file:
zip_file.write(self.shapes_path, arcname=dataexample.TestData.SHAPES)

with open(tf.name, 'rb') as zip_file:
zipped_bytes = zip_file.read()

self.mock_urlopen(mock_urlopen, zipped_bytes)
dataexample.REMOTEDATA._download_and_extract_from_url(os.path.join(tmp_path, tmp_dir))
Test the download_data function, when the user input is 'y' to 'are you sure you want to download data'
The user input to confirm the download is mocked as 'y'
The zip file download is mocked by creating a zip file locally
Test the download_data function correctly extracts files from the zip file
'''
# create a temporary folder in the CIL data directory
tmp_dir = os.path.join(dataexample.CILDATA.data_dir, str(uuid.uuid4()))
os.makedirs(tmp_dir)
# redirect print output
capturedOutput = StringIO()
sys.stdout = capturedOutput
for data in self.data_list:
test_func = getattr(dataexample, data)
test_func.download_data(tmp_dir)
# Test the data file exists
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, getattr(test_func, 'FOLDER'), dataexample.TestData.SHAPES)),
msg = "Download data test failed with dataset " + data)
# Test the zip file is removed
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, getattr(test_func, 'ZIP_FILE'))))
# return to standard print output
sys.stdout = sys.__stdout__
shutil.rmtree(tmp_dir)

self.assertTrue(os.path.isfile(os.path.join(tmp_path, tmp_dir, dataexample.TestData.SHAPES)))

if os.path.exists(os.path.join(tmp_path,tmp_dir)):
shutil.rmtree(os.path.join(tmp_path,tmp_dir))

@unittest.skipIf(platform.system() == 'Windows', "Skip on Windows")
@patch('cil.utilities.dataexample.input', return_value='n')
@patch('cil.utilities.dataexample.urlopen')
def test_download_data_input_n(self, mock_urlopen, input):
@patch('cil.utilities.dataexample.input', return_value='n')
@patch('cil.utilities.dataexample.zenodo_get', side_effect=mock_zenodo_get)
def test_download_data_input_n(self, mock_zenodo_get, input):
'''
Test the download_data function, when the user input is 'n' to 'are you sure you want to download data'
The zipped byte string is mocked using a temporary local zip file
'''

# create a temporary zip file to test the function
with NamedTemporaryFile(suffix = '.zip') as tf:
tmp_path = os.path.dirname(tf.name)
tmp_dir = os.path.splitext(os.path.basename(tf.name))[0]
with ZipFile(tf.name, mode='w') as zip_file:
zip_file.write(self.shapes_path, arcname=dataexample.TestData.SHAPES)

with open(tf.name, 'rb') as zip_file:
zipped_bytes = zip_file.read()

self.mock_urlopen(mock_urlopen, zipped_bytes)

# create a temporary folder in the CIL data directory
tmp_dir = os.path.join(dataexample.CILDATA.data_dir, str(uuid.uuid4()))
os.makedirs(tmp_dir)
for data in self.data_list:
# redirect print output
capturedOutput = StringIO()
sys.stdout = capturedOutput
capturedOutput = StringIO()
sys.stdout = capturedOutput
test_func = getattr(dataexample, data)
test_func.download_data(os.path.join(tmp_path, tmp_dir))
self.assertFalse(os.path.isfile(os.path.join(tmp_path, tmp_dir, test_func.FOLDER, dataexample.TestData.SHAPES)), msg = "Failed with dataset " + data)
self.assertEqual(capturedOutput.getvalue(),'Download cancelled\n', msg = "Failed with dataset " + data)
test_func.download_data(tmp_dir)
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, getattr(test_func, 'FOLDER'), dataexample.TestData.SHAPES)),
msg = "Download dataset test failed with dataset " + data)
self.assertEqual(capturedOutput.getvalue(),'Download cancelled\n',
msg = "Download dataset test failed with dataset " + data)
# return to standard print output
sys.stdout = sys.__stdout__

if os.path.exists(os.path.join(tmp_path,tmp_dir)):
shutil.rmtree(os.path.join(tmp_path,tmp_dir))

@unittest.skipIf(platform.system() == 'Windows', "Skip on Windows")
@patch('cil.utilities.dataexample.input', return_value='y')
@patch('cil.utilities.dataexample.urlopen')
def test_download_data_input_y(self, mock_urlopen, input):
'''
Test the download_data function, when the user input is 'y' to 'are you sure you want to download data'
The zipped byte string is mocked using a temporary local zip file
'''

with NamedTemporaryFile(suffix = '.zip') as tf:
tmp_path = os.path.dirname(tf.name)
tmp_dir = os.path.splitext(os.path.basename(tf.name))[0]
with ZipFile(tf.name, mode='w') as zip_file:
zip_file.write(self.shapes_path, arcname=dataexample.TestData.SHAPES)

with open(tf.name, 'rb') as zip_file:
zipped_bytes = zip_file.read()

self.mock_urlopen(mock_urlopen, zipped_bytes)

# redirect print output
capturedOutput = StringIO()
sys.stdout = capturedOutput
# Test the zip file IS created with prompt=False i.e. prompt not used
dataexample.WALNUT.download_data(tmp_dir, prompt=False)
# Test the data file exists
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, dataexample.WALNUT.FOLDER, dataexample.TestData.SHAPES)),
msg = "Download data test failed with dataset " + data)
# Test the zip file is removed
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, dataexample.WALNUT.ZIP_FILE)))

for data in self.data_list:
test_func = getattr(dataexample, data)
test_func.download_data(os.path.join(tmp_path, tmp_dir))
self.assertTrue(os.path.isfile(os.path.join(tmp_path, tmp_dir, test_func.FOLDER, dataexample.TestData.SHAPES)), msg = "Failed with dataset " + data)

# return to standard print output
sys.stdout = sys.__stdout__

if os.path.exists(os.path.join(tmp_path,tmp_dir)):
shutil.rmtree(os.path.join(tmp_path,tmp_dir))
shutil.rmtree(tmp_dir)


def test_download_data_bad_URL(self):
@patch('cil.utilities.dataexample.input', return_value='y')
def test_download_data_empty(self, input):
'''
Test an error is raised when _download_and_extract_from_url has an empty URL
Test an error is raised when download_data is used on an empty Zenodo record
'''
remote_data = dataexample.REMOTEDATA
remote_data.ZENODO_RECORD = 'empty'
remote_data.FOLDER = 'empty'

with self.assertRaises(ValueError):
dataexample.REMOTEDATA._download_and_extract_from_url('.')
remote_data.download_data('.')

def test_a(self):
    # NOTE(review): this test makes no assertions; it only checks that WALNUT
    # can be imported from cil.utilities.dataexample. It looks like a leftover
    # stub - confirm whether it should be removed or given a descriptive name.
    from cil.utilities.dataexample import WALNUT

Loading

0 comments on commit b91c0eb

Please sign in to comment.