From 234d5e6cf1ced0af708fd0ef65d917a29d609398 Mon Sep 17 00:00:00 2001 From: Maxime Mulder Date: Thu, 17 Oct 2024 13:47:08 -0400 Subject: [PATCH] factorize subject session --- pyproject.toml | 1 + python/lib/database_lib/session_db.py | 6 + python/lib/db/model/notification_spool.py | 3 +- python/lib/db/query/project.py | 4 + .../base_pipeline.py | 109 ++---------------- .../nifti_insertion_pipeline.py | 25 ++-- python/lib/get_subject_session.py | 68 +++++++++++ python/lib/session.py | 6 + python/lib/validate_subject_info.py | 19 ++- 9 files changed, 127 insertions(+), 114 deletions(-) create mode 100644 python/lib/get_subject_session.py diff --git a/pyproject.toml b/pyproject.toml index 445f02657..6b5a79fa3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ include = [ "python/lib/config_file.py", "python/lib/env.py", "python/lib/file_system.py", + "python/lib/get_subject_session.py", "python/lib/logging.py", "python/lib/make_env.py", "python/lib/validate_subject_info.py", diff --git a/python/lib/database_lib/session_db.py b/python/lib/database_lib/session_db.py index 764015bd1..924db4c9e 100644 --- a/python/lib/database_lib/session_db.py +++ b/python/lib/database_lib/session_db.py @@ -1,9 +1,11 @@ """This class performs session table related database queries and common checks""" +from typing_extensions import deprecated __license__ = "GPLv3" +@deprecated('Use `lib.db.model.session.DbSession` instead') class SessionDB: """ This class performs database queries for session table. @@ -35,6 +37,7 @@ def __init__(self, db, verbose): self.db = db self.verbose = verbose + @deprecated('Use `lib.db.query.try_get_candidate_with_cand_id_visit_label` instead') def create_session_dict(self, cand_id, visit_label): """ Queries the session table for a particular candidate ID and visit label and returns a dictionary @@ -56,6 +59,7 @@ def create_session_dict(self, cand_id, visit_label): return results[0] if results else None + @deprecated('Use `lib.db.query.site.try_get_site_with_psc_id_visit_label` instead') def get_session_center_info(self, pscid, visit_label): """ Get site information for a given visit. @@ -77,6 +81,7 @@ def get_session_center_info(self, pscid, visit_label): return results[0] if results else None + @deprecated('Use `lib.get_subject_session.get_candidate_next_visit_number` instead') def determine_next_session_site_id_and_visit_number(self, cand_id): """ Determines the next session site and visit number based on the last session inserted for a given candidate. @@ -99,6 +104,7 @@ def determine_next_session_site_id_and_visit_number(self, cand_id): return results[0] if results else None + @deprecated('Use `lib.db.model.session.DbSession` instead') def insert_into_session(self, fields, values): """ Insert a new row in the session table using fields list as column names and values as values. diff --git a/python/lib/db/model/notification_spool.py b/python/lib/db/model/notification_spool.py index e15acdc9a..7f57f7909 100644 --- a/python/lib/db/model/notification_spool.py +++ b/python/lib/db/model/notification_spool.py @@ -25,5 +25,4 @@ class DbNotificationSpool(Base): origin : Mapped[Optional[str]] = mapped_column('Origin') active : Mapped[bool] = mapped_column('Active', YNBool) - type : Mapped['db_notification_type.DbNotificationType'] \ - = relationship('DbNotificationType') + type : Mapped['db_notification_type.DbNotificationType'] = relationship('DbNotificationType') diff --git a/python/lib/db/query/project.py b/python/lib/db/query/project.py index 2822c7223..0d884d1e7 100644 --- a/python/lib/db/query/project.py +++ b/python/lib/db/query/project.py @@ -1,3 +1,7 @@ +<<<<<<< HEAD +======= + +>>>>>>> 917e939 (factorize subject session) from sqlalchemy import select from sqlalchemy.orm import Session as Database diff --git a/python/lib/dcm2bids_imaging_pipeline_lib/base_pipeline.py b/python/lib/dcm2bids_imaging_pipeline_lib/base_pipeline.py index c41942fc9..ad4110b21 100644 --- a/python/lib/dcm2bids_imaging_pipeline_lib/base_pipeline.py +++ b/python/lib/dcm2bids_imaging_pipeline_lib/base_pipeline.py @@ -5,6 +5,7 @@ import lib.utilities from lib.database import Database from lib.database_lib.config import Config +from lib.db.query.session import try_get_session_with_cand_id_visit_label from lib.dicom_archive import DicomArchive from lib.exception.determine_subject_info_error import DetermineSubjectInfoError from lib.exception.validate_subject_info_error import ValidateSubjectInfoError @@ -192,10 +193,15 @@ def determine_study_info(self): # get the CenterID from the session table if the PSCID and visit label exists # and could be extracted from the database - self.session_obj.create_session_dict(self.subject_info.cand_id, self.subject_info.visit_label) - session_dict = self.session_obj.session_info_dict - if session_dict: - return {"CenterName": session_dict["MRI_alias"], "CenterID": session_dict["CenterID"]} + + self.session = try_get_session_with_cand_id_visit_label( + self.env.db, + self.subject_info.cand_id, + self.subject_info.visit_label, + ) + + if self.session is not None: + return {"CenterName": self.session.site.mri_alias, "CenterID": self.session.site_id} # if could not find center information based on cand_id and visit_label, use the # patient name to match it to the site alias or MRI alias @@ -223,7 +229,7 @@ def determine_scanner_info(self): self.dicom_archive_obj.tarchive_info_dict['ScannerSerialNumber'], self.dicom_archive_obj.tarchive_info_dict['ScannerModel'], self.site_dict['CenterID'], - self.session_obj.session_info_dict['ProjectID'] if self.session_obj.session_info_dict else None + self.session.project_id if self.session is not None else None, ) log_verbose(self.env, f"Found Scanner ID: {scanner_id}") @@ -248,99 +254,6 @@ def validate_subject_info(self): upload_id=self.upload_id, fields=('IsCandidateInfoValidated',), values=('0',) ) - def get_session_info(self): - """ - Creates the session info dictionary based on entries found in the session table. - """ - - self.session_obj.create_session_dict(self.subject_info.cand_id, self.subject_info.visit_label) - - if self.session_obj.session_info_dict: - log_verbose(self.env, f"Session ID for the file to insert is {self.session_obj.session_info_dict['ID']}") - - def create_session(self): - """ - Function that will create a new visit in the session table for the imaging scans after verification - that all the information necessary for the creation of the visit are present. - """ - - create_visit = self.subject_info.create_visit - - if create_visit is None: - log_error_exit( - self.env, - f"Visit {self.subject_info.visit_label} for candidate {self.subject_info.cand_id} does not exist.", - lib.exitcode.GET_SESSION_ID_FAILURE, - ) - - # check that the project ID and cohort ID refers to an existing row in project_cohort_rel table - self.session_obj.create_proj_cohort_rel_info_dict(create_visit.project_id, create_visit.cohort_id) - if not self.session_obj.proj_cohort_rel_info_dict.keys(): - log_error_exit( - self.env, - ( - f"Cannot create visit with project ID {create_visit.project_id}" - f" and cohort ID {create_visit.cohort_id}:" - f" no such association in table project_cohort_rel" - ), - lib.exitcode.CREATE_SESSION_FAILURE, - ) - - # determine the visit number and center ID for the next session to be created - center_id, visit_nb = self.determine_new_session_site_and_visit_nb() - if not center_id: - log_error_exit( - self.env, - ( - f"No center ID found for candidate {self.subject_info.cand_id}" - f", visit {self.subject_info.visit_label}" - ) - ) - else: - log_verbose(self.env, f"Set newVisitNo = {visit_nb} and center ID = {center_id}") - - # create the new visit - session_id = self.session_obj.insert_into_session( - { - 'CandID': self.subject_info.cand_id, - 'Visit_label': self.subject_info.visit_label, - 'CenterID': center_id, - 'VisitNo': visit_nb, - 'Current_stage': 'Not Started', - 'Scan_done': 'Y', - 'Submitted': 'N', - 'CohortID': create_visit.cohort_id, - 'ProjectID': create_visit.project_id - } - ) - if session_id: - self.get_session_info() - - def determine_new_session_site_and_visit_nb(self): - """ - Determines the site and visit number of the new session to be created. - - :returns: The center ID and visit number of the future new session - """ - visit_nb = 0 - center_id = 0 - - if self.subject_info.is_phantom: - center_info_dict = self.session_obj.get_session_center_info( - self.subject_info.psc_id, self.subject_info.visit_label, - ) - - if center_info_dict: - center_id = center_info_dict["CenterID"] - visit_nb = 1 - else: - center_info_dict = self.session_obj.get_next_session_site_id_and_visit_number(self.subject_info.cand_id) - if center_info_dict: - center_id = center_info_dict["CenterID"] - visit_nb = center_info_dict["newVisitNo"] - - return center_id, visit_nb - def check_if_tarchive_validated_in_db(self): """ Checks whether the DICOM archive was previously validated in the database (as per the value present diff --git a/python/lib/dcm2bids_imaging_pipeline_lib/nifti_insertion_pipeline.py b/python/lib/dcm2bids_imaging_pipeline_lib/nifti_insertion_pipeline.py index 1c5e95f77..5bde62279 100644 --- a/python/lib/dcm2bids_imaging_pipeline_lib/nifti_insertion_pipeline.py +++ b/python/lib/dcm2bids_imaging_pipeline_lib/nifti_insertion_pipeline.py @@ -11,6 +11,7 @@ from lib.dcm2bids_imaging_pipeline_lib.base_pipeline import BasePipeline from lib.exception.determine_subject_info_error import DetermineSubjectInfoError from lib.exception.validate_subject_info_error import ValidateSubjectInfoError +from lib.get_subject_session import get_subject_session from lib.logging import log_error_exit, log_verbose from lib.validate_subject_info import validate_subject_info @@ -115,9 +116,7 @@ def __init__(self, loris_getopt_obj, script_name): # --------------------------------------------------------------------------------------------- # Determine/create the session the file should be linked to # --------------------------------------------------------------------------------------------- - self.get_session_info() - if not self.session_obj.session_info_dict: - self.create_session() + self.session = get_subject_session(self.env.db, self.subject_info) # --------------------------------------------------------------------------------------------- # Determine acquisition protocol (or register into mri_protocol_violated_scans and exits) @@ -174,9 +173,9 @@ def __init__(self, loris_getopt_obj, script_name): self.exclude_violations_list = [] if not self.bypass_extra_checks: self.violations_summary = self.imaging_obj.run_extra_file_checks( - self.session_obj.session_info_dict['ProjectID'], - self.session_obj.session_info_dict['CohortID'], - self.session_obj.session_info_dict['Visit_label'], + self.session.project_id, + self.session.cohort_id, + self.session.visit_label, self.scan_type_id, self.json_file_dict ) @@ -362,15 +361,15 @@ def _determine_acquisition_protocol(self): self.json_file_dict['DeviceSerialNumber'], self.json_file_dict['ManufacturersModelName'], self.site_dict['CenterID'], - self.session_obj.session_info_dict['ProjectID'] + self.session.project_id, ) # get the list of lines in the mri_protocol table that apply to the given scan based on the protocol group protocols_list = self.imaging_obj.get_list_of_eligible_protocols_based_on_session_info( - self.session_obj.session_info_dict['ProjectID'], - self.session_obj.session_info_dict['CohortID'], - self.session_obj.session_info_dict['CenterID'], - self.session_obj.session_info_dict['Visit_label'], + self.session.project_id, + self.session.cohort_id, + self.session.site_id, + self.session.visit_label, self.scanner_id ) @@ -463,7 +462,7 @@ def _determine_new_nifti_assembly_rel_path(self): # determine NIfTI file name new_nifti_name = self._construct_nifti_filename(file_bids_entities_dict) already_inserted_filenames = self.imaging_obj.get_list_of_files_already_inserted_for_session_id( - self.session_obj.session_info_dict['ID'] + self.session.id, ) while new_nifti_name in already_inserted_filenames: file_bids_entities_dict['run'] += 1 @@ -685,7 +684,7 @@ def _register_into_files_and_parameter_file(self, nifti_rel_path): ) files_insert_info_dict = { - 'SessionID': self.session_obj.session_info_dict['ID'], + 'SessionID': self.session.id, 'File': nifti_rel_path, 'SeriesUID': scan_param['SeriesInstanceUID'] if 'SeriesInstanceUID' in scan_param.keys() else None, 'EchoTime': scan_param['EchoTime'] if 'EchoTime' in scan_param.keys() else None, diff --git a/python/lib/get_subject_session.py b/python/lib/get_subject_session.py new file mode 100644 index 000000000..9d8a33808 --- /dev/null +++ b/python/lib/get_subject_session.py @@ -0,0 +1,68 @@ +from sqlalchemy.orm import Session as Database + +from lib.dataclass.config import SubjectConfig +from lib.db.model.candidate import DbCandidate +from lib.db.model.session import DbSession +from lib.db.query.site import try_get_site_with_psc_id_visit_label +from lib.db.query.session import try_get_session_with_cand_id_visit_label + + +def get_candidate_next_visit_number(candidate: DbCandidate): + """ + Get the next visit number for a new session for a given candidate. + """ + + visit_numbers = [session.visit_number for session in candidate.sessions if session.visit_number is not None] + return max(*visit_numbers, 0) + 1 + + +def get_subject_session(db: Database, subject: SubjectConfig) -> DbSession: + """ + Get the imaging session corresponding to a given subject configuration. + + This function first looks for an adequate session in the database, and returns it if one is + found. If no session is found, this function creates a new session in the database if the + subject configuration allows it, or exits the program otherwise. + """ + + session = try_get_session_with_cand_id_visit_label(db, subject.cand_id, subject.visit_label) + if session is not None: + # TODO: Log + # f"Session ID for the file to insert is {self.session_obj.session_info_dict['ID']}" + # self.log_info(message, is_error="N", is_verbose="Y") + return session + + if subject.create_visit is None: + # TODO: Log and exit + # f"Visit {self.subject.visit_label} for candidate {self.subject.cand_id} does not exist." + # self.log_error_and_exit(message, lib.exitcode.GET_SESSION_ID_FAILURE, is_error="Y", is_verbose="N") + return exit(-1) + + if subject.is_phantom: + site = try_get_site_with_psc_id_visit_label(db, subject.psc_id, subject.visit_label) + visit_number = 1 + else: + # TODO: Get real candidate + candidate = DbCandidate() + site = candidate.registration_site + visit_number = get_candidate_next_visit_number(candidate) + + if site is None: + # message = f"No center ID found for candidate {self.subject.cand_id}, visit {self.subject.visit_label}" + return exit(-1) + + session = DbSession( + cand_id = subject.cand_id, + site_id = site.id, + visit_number = visit_number, + current_stage = 'Not Started', + scan_done = 'Y', + submitted = 'N', + project_id = subject.create_visit.project_id, + cohort_id = subject.create_visit.cohort_id, + ) + + db.add(session) + db.flush() + + return session diff --git a/python/lib/session.py b/python/lib/session.py index 81fe021a0..195b17ea1 100644 --- a/python/lib/session.py +++ b/python/lib/session.py @@ -1,5 +1,7 @@ """This class gather functions for session handling.""" +from typing_extensions import deprecated + from lib.database_lib.project_cohort_rel import ProjectCohortRel from lib.database_lib.session_db import SessionDB from lib.database_lib.site import Site @@ -126,6 +128,7 @@ def get_session_info_from_loris(self): return loris_session_info[0] if loris_session_info else None + @deprecated('Use `lib.db.query.site.try_get_site_with_psc_id_visit_label` instead') def get_session_center_info(self, pscid, visit_label): """ Get the session center information based on the PSCID and visit label of a session. @@ -140,6 +143,7 @@ def get_session_center_info(self, pscid, visit_label): """ return self.session_db_obj.get_session_center_info(pscid, visit_label) + @deprecated('Use `lib.db.query.try_get_candidate_with_cand_id_visit_label` instead') def create_session_dict(self, cand_id, visit_label): """ Creates the session information dictionary based on a candidate ID and visit label. This will populate @@ -159,6 +163,7 @@ def create_session_dict(self, cand_id, visit_label): self.cohort_id = self.session_info_dict['CohortID'] self.session_id = self.session_info_dict['ID'] + @deprecated('Use `lib.db.model.session.DbSession` instead') def insert_into_session(self, session_info_to_insert_dict): """ Insert a new row in the session table using fields list as column names and values as values. @@ -176,6 +181,7 @@ def insert_into_session(self, session_info_to_insert_dict): return self.session_id + @deprecated('Use `lib.get_subject_session.get_candidate_next_visit_number` instead') def get_next_session_site_id_and_visit_number(self, cand_id): """ Determines the next session site and visit number based on the last session inserted for a given candidate. diff --git a/python/lib/validate_subject_info.py b/python/lib/validate_subject_info.py index 65b84689b..7467fa669 100644 --- a/python/lib/validate_subject_info.py +++ b/python/lib/validate_subject_info.py @@ -4,6 +4,7 @@ from lib.config_file import SubjectInfo from lib.db.query.candidate import try_get_candidate_with_cand_id +from lib.db.query.project import try_get_project_cohort_with_project_id_cohort_id from lib.db.query.visit import try_get_visit_window_with_visit_label from lib.exception.validate_subject_info_error import ValidateSubjectInfoError @@ -29,12 +30,28 @@ def validate_subject_info(db: Database, subject_info: SubjectInfo): ) visit_window = try_get_visit_window_with_visit_label(db, subject_info.visit_label) - if visit_window is None and subject_info.create_visit is not None: + if visit_window is not None: + return + + if subject_info.create_visit is None: validate_subject_error( subject_info, f'Visit label \'{subject_info.visit_label}\' does not exist in the database (table `Visit_Windows`).' ) + project_id = subject_info.create_visit.project_id + cohort_id = subject_info.create_visit.cohort_id + + project_cohort = try_get_project_cohort_with_project_id_cohort_id(db, project_id, cohort_id) + if project_cohort is None: + validate_subject_error( + subject_info, + ( + f'Cannot create a session with project ID {project_id} and cohort ID {cohort_id}.\n' + f'This project and this cohort are not associated in the database (table `project_cohort_rel`).' + ), + ) + def validate_subject_error(subject_info: SubjectInfo, message: str) -> Never: raise ValidateSubjectInfoError(f'Validation error for subject \'{subject_info.name}\'.\n{message}')