Skip to content

Commit

Permalink
Merge pull request #131 from GSTT-CSC/129-handle-multtiple-valid-expe…
Browse files Browse the repository at this point in the history
…riments-in-xnat

129 handle multtiple valid experiments in xnat
  • Loading branch information
laurencejackson authored Jan 24, 2023
2 parents d4564d1 + 7125c66 commit cce5384
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions mlops/data/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,19 @@
class DataBuilderXNAT:

def __init__(self, xnat_configuration: dict, actions: list = None, flatten_output=True, test_batch: int = -1,
num_workers: int = 1):
num_workers: int = 1, validate_data=True):
self.xnat_configuration = xnat_configuration
self.actions = actions
self.flatten_output = flatten_output
self.test_batch = test_batch
self.missing_data_log = []
self.num_workers = num_workers
self.validate_data = validate_data

self.dataset = []

def fetch_data(self):
loop = asyncio.get_event_loop()
future = asyncio.ensure_future(self.start_async_process())
loop.run_until_complete(future)
asyncio.run(self.start_async_process())

async def start_async_process(self):
with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
Expand All @@ -45,8 +44,6 @@ async def start_async_process(self):
logger.info(f"Collecting XNAT project: {self.xnat_configuration['project']}")
project = session.projects[self.xnat_configuration["project"]]

dataset = []

if 0 < self.test_batch < len(project.subjects):
from random import sample
project_subjects = sample(project.subjects[:], self.test_batch)
Expand All @@ -66,7 +63,8 @@ async def start_async_process(self):
pass

# remove any items where not all actions returned a value
self.dataset = [item for item in self.dataset if len(item['data']) == len(self.actions)]
if self.validate_data:
self.dataset = [item for item in self.dataset if len(item['data']) >= len(self.actions)]

def process_subject(self, project, subject_i):
subject = project.subjects.data[subject_i.id]
Expand All @@ -85,21 +83,21 @@ def process_subject(self, project, subject_i):
# logger.debug(f"Running action: {action.__name__} on {subject.id}")
xnat_obj = action(project.subjects[subject.id])

if type(xnat_obj) == list:
if len(xnat_obj) == 0:
self.missing_data_log.append({'subject_id': subject_i.id,
'action_data': subject_i.label,
'failed_action': action})
logger.warn(f'No data found for {subject_i}: action {action} removing sample')
raise Exception
if xnat_obj is None or type(xnat_obj) == list and len(xnat_obj) == 0:
self.missing_data_log.append({'subject_id': subject_i.id,
'action_data': subject_i.label,
'failed_action': action})
logger.warn(f'No data found for {subject_i}: action {action} removing sample')
raise Exception

elif type(xnat_obj) == list:
for obj in xnat_obj:
action_data.append({'source_action': action.__name__,
'action_data': obj.uri,
'data_type': 'xnat_uri',
'data_label': data_label})

elif type(xnat_obj) == str:
else:
action_data.append({'source_action': action.__name__,
'action_data': xnat_obj,
'data_type': 'value',
Expand Down

0 comments on commit cce5384

Please sign in to comment.