diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index daf1118..d69971f 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -67,7 +67,7 @@ jobs: with: coverageFile: coverage.xml token: ${{ secrets.GITHUB_TOKEN }} - thresholdAll: .97 + thresholdAll: .98 thresholdNew: 1 thresholdModified: 1 diff --git a/cumulus_etl/etl/cli.py b/cumulus_etl/etl/cli.py index 3fe38b6..1d2504b 100644 --- a/cumulus_etl/etl/cli.py +++ b/cumulus_etl/etl/cli.py @@ -192,15 +192,15 @@ def print_config( def handle_completion_args( - args: argparse.Namespace, loader: loaders.Loader + args: argparse.Namespace, loader_results: loaders.LoaderResults ) -> (str, datetime.datetime): """Returns (group_name, datetime)""" # Grab completion options from CLI or loader - export_group_name = args.export_group or loader.group_name + export_group_name = args.export_group or loader_results.group_name export_datetime = ( datetime.datetime.fromisoformat(args.export_timestamp) if args.export_timestamp - else loader.export_datetime + else loader_results.export_datetime ) # Disable entirely if asked to @@ -267,14 +267,14 @@ async def etl_main(args: argparse.Namespace) -> None: ) # Pull down resources from any remote location (like s3), convert from i2b2, or do a bulk export - loaded_dir = await config_loader.load_all(list(required_resources)) + loader_results = await config_loader.load_all(list(required_resources)) # Establish the group name and datetime of the loaded dataset (from CLI args or Loader) - export_group_name, export_datetime = handle_completion_args(args, config_loader) + export_group_name, export_datetime = handle_completion_args(args, loader_results) # If *any* of our tasks need bulk MS de-identification, run it if any(t.needs_bulk_deid for t in selected_tasks): - loaded_dir = await deid.Scrubber.scrub_bulk_data(loaded_dir.name) + loader_results.directory = await deid.Scrubber.scrub_bulk_data(loader_results.path) else: print("Skipping bulk de-identification.") print("These selected tasks will de-identify resources as they are processed.") @@ -282,7 +282,7 @@ async def etl_main(args: argparse.Namespace) -> None: # Prepare config for jobs config = JobConfig( args.dir_input, - loaded_dir.name, + loader_results.path, args.dir_output, args.dir_phi, args.input_format, @@ -296,6 +296,7 @@ async def etl_main(args: argparse.Namespace) -> None: tasks=[t.name for t in selected_tasks], export_group_name=export_group_name, export_datetime=export_datetime, + deleted_ids=loader_results.deleted_ids, ) common.write_json(config.path_config(), config.as_json(), indent=4) diff --git a/cumulus_etl/etl/config.py b/cumulus_etl/etl/config.py index 9195ae2..2491876 100644 --- a/cumulus_etl/etl/config.py +++ b/cumulus_etl/etl/config.py @@ -33,6 +33,7 @@ def __init__( tasks: list[str] | None = None, export_group_name: str | None = None, export_datetime: datetime.datetime | None = None, + deleted_ids: dict[str, set[str]] | None = None, ): self._dir_input_orig = dir_input_orig self.dir_input = dir_input_deid @@ -50,6 +51,7 @@ def __init__( self.tasks = tasks or [] self.export_group_name = export_group_name self.export_datetime = export_datetime + self.deleted_ids = deleted_ids or {} # initialize format class self._output_root = store.Root(self._dir_output, create=True) diff --git a/cumulus_etl/etl/convert/cli.py b/cumulus_etl/etl/convert/cli.py index ffc2957..6ee7dd5 100644 --- a/cumulus_etl/etl/convert/cli.py +++ b/cumulus_etl/etl/convert/cli.py @@ -36,6 +36,20 @@ def make_batch( return formats.Batch(rows, groups=groups, schema=schema) +def convert_table_metadata( + meta_path: str, + formatter: formats.Format, +) -> None: + try: + meta = common.read_json(meta_path) + except (FileNotFoundError, PermissionError): + return + + # Only one metadata field currently: deleted IDs + deleted = meta.get("deleted", []) + formatter.delete_records(set(deleted)) + + def convert_folder( input_root: store.Root, *, @@ -66,6 +80,7 @@ def convert_folder( formatter.write_records(batch) progress.update(progress_task, advance=1) + convert_table_metadata(f"{table_input_dir}/{table_name}.meta", formatter) formatter.finalize() progress.update(progress_task, advance=1) @@ -117,14 +132,15 @@ def convert_completion( def copy_job_configs(input_root: store.Root, output_root: store.Root) -> None: with tempfile.TemporaryDirectory() as tmpdir: - job_config_path = input_root.joinpath("JobConfig") + job_config_path = input_root.joinpath("JobConfig/") # Download input dir if it's not local if input_root.protocol != "file": - input_root.get(job_config_path, tmpdir, recursive=True) - job_config_path = os.path.join(tmpdir, "JobConfig") + new_location = os.path.join(tmpdir, "JobConfig/") + input_root.get(job_config_path, new_location, recursive=True) + job_config_path = new_location - output_root.put(job_config_path, output_root.path, recursive=True) + output_root.put(job_config_path, output_root.joinpath("JobConfig/"), recursive=True) def walk_tree( diff --git a/cumulus_etl/etl/tasks/base.py b/cumulus_etl/etl/tasks/base.py index f4b39ae..4aa9e99 100644 --- a/cumulus_etl/etl/tasks/base.py +++ b/cumulus_etl/etl/tasks/base.py @@ -143,11 +143,15 @@ async def run(self) -> list[config.JobSummary]: with self._indeterminate_progress(progress, "Finalizing"): # Ensure that we touch every output table (to create them and/or to confirm schema). - # Consider case of Medication for an EHR that only has inline Medications inside MedicationRequest. - # The Medication table wouldn't get created otherwise. Plus this is a good place to push any schema - # changes. (The reason it's nice if the table & schema exist is so that downstream SQL can be dumber.) + # Consider case of Medication for an EHR that only has inline Medications inside + # MedicationRequest. The Medication table wouldn't get created otherwise. + # Plus this is a good place to push any schema changes. (The reason it's nice if + # the table & schema exist is so that downstream SQL can be dumber.) self._touch_remaining_tables() + # If the input data indicates we should delete some IDs, do that here. + self._delete_requested_ids() + # Mark this group & resource combo as complete self._update_completion_table() @@ -228,6 +232,32 @@ def _touch_remaining_tables(self): # just write an empty dataframe (should be fast) self._write_one_table_batch([], table_index, 0) + def _delete_requested_ids(self): + """ + Deletes IDs that have been marked for deletion. + + Formatters are expected to already exist when this is called. + + This usually happens via the `deleted` array from a bulk export. + Which clients usually drop in a deleted/ folder in the download directory. + But in our case, that's abstracted away into a JobConfig.deleted_ids dictionary. + """ + for index, output in enumerate(self.outputs): + resource = output.get_resource_type(self) + if not resource or resource.lower() != output.get_name(self): + # Only delete from the main table for the resource + continue + + deleted_ids = self.task_config.deleted_ids.get(resource, set()) + if not deleted_ids: + continue + + deleted_ids = { + self.scrubber.codebook.fake_id(resource, x, caching_allowed=False) + for x in deleted_ids + } + self.formatters[index].delete_records(deleted_ids) + def _update_completion_table(self) -> None: # TODO: what about empty sets - do we assume the export gave 0 results or skip it? # Is there a difference we could notice? (like empty input file vs no file at all) diff --git a/cumulus_etl/formats/base.py b/cumulus_etl/formats/base.py index 8f56dd4..7c56f8d 100644 --- a/cumulus_etl/formats/base.py +++ b/cumulus_etl/formats/base.py @@ -74,6 +74,14 @@ def _write_one_batch(self, batch: Batch) -> None: :param batch: the batch of data """ + @abc.abstractmethod + def delete_records(self, ids: set[str]) -> None: + """ + Deletes all mentioned IDs from the table. + + :param ids: all IDs to remove + """ + def finalize(self) -> None: """ Performs any necessary cleanup after all batches have been written. diff --git a/cumulus_etl/formats/batched_files.py b/cumulus_etl/formats/batched_files.py index 9369a19..86e9f0c 100644 --- a/cumulus_etl/formats/batched_files.py +++ b/cumulus_etl/formats/batched_files.py @@ -83,3 +83,14 @@ def _write_one_batch(self, batch: Batch) -> None: full_path = self.dbroot.joinpath(f"{self.dbname}.{self._index:03}.{self.suffix}") self.write_format(batch, full_path) self._index += 1 + + def delete_records(self, ids: set[str]) -> None: + """ + Deletes the given IDs. + + Though this is a no-op for batched file outputs, since: + - we guarantee the output folder is empty at the start + - the spec says deleted IDs won't overlap with output IDs + + But subclasses may still want to write these to disk to preserve the metadata. + """ diff --git a/cumulus_etl/formats/deltalake.py b/cumulus_etl/formats/deltalake.py index 1e8fc60..06aa286 100644 --- a/cumulus_etl/formats/deltalake.py +++ b/cumulus_etl/formats/deltalake.py @@ -91,8 +91,6 @@ def initialize_class(cls, root: store.Root) -> None: def _write_one_batch(self, batch: Batch) -> None: """Writes the whole dataframe to a delta lake""" with self.batch_to_spark(batch) as updates: - if updates is None: - return delta_table = self.update_delta_table(updates, groups=batch.groups) delta_table.generate("symlink_format_manifest") @@ -131,16 +129,25 @@ def update_delta_table( return table - def finalize(self) -> None: - """Performs any necessary cleanup after all batches have been written""" - full_path = self._table_path(self.dbname) + def delete_records(self, ids: set[str]) -> None: + """Deletes the given IDs.""" + if not ids: + return + + table = self._load_table() + if not table: + return try: - table = delta.DeltaTable.forPath(self.spark, full_path) - except AnalysisException: - return # if the table doesn't exist because we didn't write anything, that's fine - just bail + id_list = "', '".join(ids) + table.delete(f"id in ('{id_list}')") except Exception: - logging.exception("Could not finalize Delta Lake table %s", self.dbname) + logging.exception("Could not delete IDs from Delta Lake table %s", self.dbname) + + def finalize(self) -> None: + """Performs any necessary cleanup after all batches have been written""" + table = self._load_table() + if not table: return try: @@ -154,6 +161,19 @@ def _table_path(self, dbname: str) -> str: # hadoop uses the s3a: scheme instead of s3: return self.root.joinpath(dbname).replace("s3://", "s3a://") + def _load_table(self) -> delta.DeltaTable | None: + full_path = self._table_path(self.dbname) + + try: + return delta.DeltaTable.forPath(self.spark, full_path) + except AnalysisException: + # The table likely doesn't exist. + # Which can be normal if we didn't write anything yet, that's fine - just bail. + return None + except Exception: + logging.exception("Could not load Delta Lake table %s", self.dbname) + return None + @staticmethod def _get_update_condition(schema: pyspark.sql.types.StructType) -> str | None: """ @@ -214,7 +234,7 @@ def _configure_fs(root: store.Root, spark: pyspark.sql.SparkSession): spark.conf.set("fs.s3a.endpoint.region", region_name) @contextlib.contextmanager - def batch_to_spark(self, batch: Batch) -> pyspark.sql.DataFrame | None: + def batch_to_spark(self, batch: Batch) -> pyspark.sql.DataFrame: """Transforms a batch to a spark DF""" # This is the quick and dirty way - write batch to parquet with pyarrow and read it back. # But a more direct way would be to convert the pyarrow schema to a pyspark schema and just diff --git a/cumulus_etl/formats/ndjson.py b/cumulus_etl/formats/ndjson.py index 22358eb..a7ea5fc 100644 --- a/cumulus_etl/formats/ndjson.py +++ b/cumulus_etl/formats/ndjson.py @@ -35,3 +35,22 @@ def write_format(self, batch: Batch, path: str) -> None: # This is mostly used in tests and debugging, so we'll write out sparse files (no null columns) common.write_rows_to_ndjson(path, batch.rows, sparse=True) + + def table_metadata_path(self) -> str: + return self.dbroot.joinpath(f"{self.dbname}.meta") # no batch number + + def read_table_metadata(self) -> dict: + try: + return common.read_json(self.table_metadata_path()) + except (FileNotFoundError, PermissionError): + return {} + + def write_table_metadata(self, metadata: dict) -> None: + self.root.makedirs(self.dbroot.path) + common.write_json(self.table_metadata_path(), metadata, indent=2) + + def delete_records(self, ids: set[str]) -> None: + # Read and write back table metadata, with the addition of these new deleted IDs + meta = self.read_table_metadata() + meta.setdefault("deleted", []).extend(sorted(ids)) + self.write_table_metadata(meta) diff --git a/cumulus_etl/loaders/__init__.py b/cumulus_etl/loaders/__init__.py index 46a372a..415f241 100644 --- a/cumulus_etl/loaders/__init__.py +++ b/cumulus_etl/loaders/__init__.py @@ -1,5 +1,5 @@ """Public API for loaders""" -from .base import Loader +from .base import Loader, LoaderResults from .fhir.ndjson_loader import FhirNdjsonLoader from .i2b2.loader import I2b2Loader diff --git a/cumulus_etl/loaders/base.py b/cumulus_etl/loaders/base.py index c2f555a..f773df0 100644 --- a/cumulus_etl/loaders/base.py +++ b/cumulus_etl/loaders/base.py @@ -1,10 +1,33 @@ """Base abstract loader""" import abc +import dataclasses +import datetime from cumulus_etl import common, store +@dataclasses.dataclass(kw_only=True) +class LoaderResults: + """Bundles results of a load request""" + + # Where loaded files reside on disk (use .path for convenience) + directory: common.Directory + + @property + def path(self) -> str: + return self.directory.name + + # Completion tracking values - noting an export group name for this bundle of data + # and the time when it was exported ("transactionTime" in bulk-export terms). + group_name: str | None = None + export_datetime: datetime.datetime | None = None + + # A list of resource IDs that should be deleted from the output tables. + # This is a map of resource -> set of IDs like {"Patient": {"A", "B"}} + deleted_ids: dict[str, set[str]] = dataclasses.field(default_factory=dict) + + class Loader(abc.ABC): """ An abstraction for how to load FHIR input @@ -21,12 +44,8 @@ def __init__(self, root: store.Root): """ self.root = root - # Public properties (potentially set when loading) for reporting back to caller - self.group_name = None - self.export_datetime = None - @abc.abstractmethod - async def load_all(self, resources: list[str]) -> common.Directory: + async def load_all(self, resources: list[str]) -> LoaderResults: """ Loads the listed remote resources and places them into a local folder as FHIR ndjson diff --git a/cumulus_etl/loaders/fhir/ndjson_loader.py b/cumulus_etl/loaders/fhir/ndjson_loader.py index f211c92..d304b44 100644 --- a/cumulus_etl/loaders/fhir/ndjson_loader.py +++ b/cumulus_etl/loaders/fhir/ndjson_loader.py @@ -37,11 +37,11 @@ def __init__( self.until = until self.resume = resume - async def load_all(self, resources: list[str]) -> common.Directory: + async def load_all(self, resources: list[str]) -> base.LoaderResults: # Are we doing a bulk FHIR export from a server? if self.root.protocol in ["http", "https"]: - loaded_dir = await self.load_from_bulk_export(resources) - input_root = store.Root(loaded_dir.name) + results = await self.load_from_bulk_export(resources) + input_root = store.Root(results.path) else: if self.export_to or self.since or self.until or self.resume: errors.fatal( @@ -49,18 +49,21 @@ async def load_all(self, resources: list[str]) -> common.Directory: errors.ARGS_CONFLICT, ) + results = base.LoaderResults(directory=self.root.path) input_root = self.root # Parse logs for export information try: parser = BulkExportLogParser(input_root) - self.group_name = parser.group_name - self.export_datetime = parser.export_datetime + results.group_name = parser.group_name + results.export_datetime = parser.export_datetime except BulkExportLogParser.LogParsingError: # Once we require group name & export datetime, we should warn about this. # For now, just ignore any errors. pass + results.deleted_ids = self.read_deleted_ids(input_root) + # Copy the resources we need from the remote directory (like S3 buckets) to a local one. # # We do this even if the files are local, because the next step in our pipeline is the MS deid tool, @@ -75,11 +78,13 @@ async def load_all(self, resources: list[str]) -> common.Directory: filenames = common.ls_resources(input_root, set(resources), warn_if_empty=True) for filename in filenames: input_root.get(filename, f"{tmpdir.name}/") - return tmpdir + results.directory = tmpdir + + return results async def load_from_bulk_export( self, resources: list[str], prefer_url_resources: bool = False - ) -> common.Directory: + ) -> base.LoaderResults: """ Performs a bulk export and drops the results in an export dir. @@ -101,11 +106,37 @@ async def load_from_bulk_export( ) await bulk_exporter.export() - # Copy back these settings from the export - self.group_name = bulk_exporter.group_name - self.export_datetime = bulk_exporter.export_datetime - except errors.FatalError as exc: errors.fatal(str(exc), errors.BULK_EXPORT_FAILED) - return target_dir + return base.LoaderResults( + directory=target_dir, + group_name=bulk_exporter.group_name, + export_datetime=bulk_exporter.export_datetime, + ) + + def read_deleted_ids(self, root: store.Root) -> dict[str, set[str]]: + """ + Reads any deleted IDs that a bulk export gave us. + + See https://hl7.org/fhir/uv/bulkdata/export.html for details. + """ + deleted_ids = {} + + subdir = store.Root(root.joinpath("deleted")) + bundles = common.read_resource_ndjson(subdir, "Bundle") + for bundle in bundles: + if bundle.get("type") != "transaction": + continue + for entry in bundle.get("entry", []): + request = entry.get("request", {}) + if request.get("method") != "DELETE": + continue + url = request.get("url") + # Sanity check that we have a relative URL like "Patient/123" + if not url or url.count("/") != 1: + continue + resource, res_id = url.split("/") + deleted_ids.setdefault(resource, set()).add(res_id) + + return deleted_ids diff --git a/cumulus_etl/loaders/i2b2/loader.py b/cumulus_etl/loaders/i2b2/loader.py index dd98617..42f7240 100644 --- a/cumulus_etl/loaders/i2b2/loader.py +++ b/cumulus_etl/loaders/i2b2/loader.py @@ -8,7 +8,7 @@ from typing import TypeVar from cumulus_etl import cli_utils, common, store -from cumulus_etl.loaders.base import Loader +from cumulus_etl.loaders import base from cumulus_etl.loaders.i2b2 import extract, schema, transform from cumulus_etl.loaders.i2b2.oracle import extract as oracle_extract @@ -18,7 +18,7 @@ I2b2ToFhirCallable = Callable[[AnyDimension], dict] -class I2b2Loader(Loader): +class I2b2Loader(base.Loader): """ Loader for i2b2 data. @@ -34,11 +34,12 @@ def __init__(self, root: store.Root, export_to: str | None = None): super().__init__(root) self.export_to = export_to - async def load_all(self, resources: list[str]) -> common.Directory: + async def load_all(self, resources: list[str]) -> base.LoaderResults: if self.root.protocol in ["tcp"]: - return self._load_all_from_oracle(resources) - - return self._load_all_from_csv(resources) + directory = self._load_all_from_oracle(resources) + else: + directory = self._load_all_from_csv(resources) + return base.LoaderResults(directory=directory) def _load_all_with_extractors( self, diff --git a/tests/convert/test_convert_cli.py b/tests/convert/test_convert_cli.py index c366267..e52b181 100644 --- a/tests/convert/test_convert_cli.py +++ b/tests/convert/test_convert_cli.py @@ -8,12 +8,11 @@ import ddt from cumulus_etl import cli, common, errors -from tests import utils +from tests import s3mock, utils -@ddt.ddt -class TestConvert(utils.AsyncTestCase): - """Tests for high-level convert support.""" +class ConvertTestsBase(utils.AsyncTestCase): + """Base class for convert tests""" def setUp(self): super().setUp() @@ -25,6 +24,21 @@ def setUp(self): self.original_path = os.path.join(self.tmpdir, "original") self.target_path = os.path.join(self.tmpdir, "target") + async def run_convert( + self, input_path: str | None = None, output_path: str | None = None + ) -> None: + args = [ + "convert", + input_path or self.original_path, + output_path or self.target_path, + ] + await cli.main(args) + + +@ddt.ddt +class TestConvert(ConvertTestsBase): + """Tests for high-level convert support.""" + def prepare_original_dir(self) -> str: """Returns the job timestamp used, for easier inspection""" # Fill in original dir, including a non-default output folder @@ -47,16 +61,6 @@ def prepare_original_dir(self) -> str: return job_timestamp - async def run_convert( - self, input_path: str | None = None, output_path: str | None = None - ) -> None: - args = [ - "convert", - input_path or self.original_path, - output_path or self.target_path, - ] - await cli.main(args) - async def test_input_dir_must_exist(self): """Verify that the input dir must already exist""" with self.assertRaises(SystemExit) as cm: @@ -150,7 +154,10 @@ async def test_happy_path(self): {"test": True}, common.read_json(f"{self.target_path}/JobConfig/{job_timestamp}/job_config.json"), ) - self.assertEqual({"delta": "yup"}, common.read_json(f"{delta_config_dir}/job_config.json")) + self.assertEqual( + {"delta": "yup"}, + common.read_json(f"{self.target_path}/JobConfig/{delta_timestamp}/job_config.json"), + ) patients = utils.read_delta_lake(f"{self.target_path}/patient") # re-check the patients self.assertEqual(3, len(patients)) # these rows are sorted by id, so these are reliable indexes @@ -211,3 +218,54 @@ async def test_batch_metadata(self, mock_write): ) # second (faked) covid batch self.assertEqual({"nonexistent"}, mock_write.call_args_list[2][0][0].groups) + + @mock.patch("cumulus_etl.formats.Format.write_records") + @mock.patch("cumulus_etl.formats.deltalake.DeltaLakeFormat.delete_records") + async def test_deleted_ids(self, mock_delete, mock_write): + """Verify that we pass along deleted IDs in the table metadata""" + # Set up input path + shutil.copytree( # First, one table that has no metadata + f"{self.datadir}/simple/output/patient", + f"{self.original_path}/patient", + ) + shutil.copytree( # Then, one that will + f"{self.datadir}/simple/output/condition", + f"{self.original_path}/condition", + ) + common.write_json(f"{self.original_path}/condition/condition.meta", {"deleted": ["a", "b"]}) + os.makedirs(f"{self.original_path}/JobConfig") + + # Run conversion + await self.run_convert() + + # Verify results + self.assertEqual(mock_write.call_count, 2) + self.assertEqual(mock_delete.call_count, 1) + self.assertEqual(mock_delete.call_args, mock.call({"a", "b"})) + + +class TestConvertOnS3(s3mock.S3Mixin, ConvertTestsBase): + @mock.patch("cumulus_etl.formats.Format.write_records") + async def test_convert_from_s3(self, mock_write): + """Quick test that we can read from an arbitrary input dir using fsspec""" + # Set up input + common.write_json( + f"{self.bucket_url}/JobConfig/2024-08-09__16.32.51/job_config.json", + {"comment": "unittest"}, + ) + common.write_json( + f"{self.bucket_url}/condition/condition.000.ndjson", + {"id": "con1"}, + ) + + # Run conversion + await self.run_convert(input_path=self.bucket_url) + + # Verify results + self.assertEqual(mock_write.call_count, 1) + self.assertEqual([{"id": "con1"}], mock_write.call_args[0][0].rows) + print(os.listdir(f"{self.target_path}")) + self.assertEqual( + common.read_json(f"{self.target_path}/JobConfig/2024-08-09__16.32.51/job_config.json"), + {"comment": "unittest"}, + ) diff --git a/tests/etl/test_etl_cli.py b/tests/etl/test_etl_cli.py index 35dc58b..c057927 100644 --- a/tests/etl/test_etl_cli.py +++ b/tests/etl/test_etl_cli.py @@ -267,20 +267,16 @@ async def test_task_init_checks(self, mock_check): async def test_completion_args(self, etl_args, loader_vals, expected_vals): """Verify that we parse completion args with the correct fallbacks and checks.""" # Grab all observations before we mock anything - observations = loaders.FhirNdjsonLoader(store.Root(self.input_path)).load_all( + observations = await loaders.FhirNdjsonLoader(store.Root(self.input_path)).load_all( ["Observation"] ) - - def fake_load_all(internal_self, resources): - del resources - internal_self.group_name = loader_vals[0] - internal_self.export_datetime = loader_vals[1] - return observations + observations.group_name = loader_vals[0] + observations.export_datetime = loader_vals[1] with ( self.assertRaises(SystemExit) as cm, mock.patch("cumulus_etl.etl.cli.etl_job", side_effect=SystemExit) as mock_etl_job, - mock.patch.object(loaders.FhirNdjsonLoader, "load_all", new=fake_load_all), + mock.patch.object(loaders.FhirNdjsonLoader, "load_all", return_value=observations), ): await self.run_etl(tasks=["observation"], **etl_args) @@ -291,6 +287,24 @@ def fake_load_all(internal_self, resources): self.assertEqual(expected_vals[0], config.export_group_name) self.assertEqual(expected_vals[1], config.export_datetime) + async def test_deleted_ids_passed_down(self): + """Verify that we parse pass along any deleted ids to the JobConfig.""" + with tempfile.TemporaryDirectory() as tmpdir: + results = loaders.LoaderResults( + directory=common.RealDirectory(tmpdir), deleted_ids={"Observation": {"obs1"}} + ) + + with ( + self.assertRaises(SystemExit), + mock.patch("cumulus_etl.etl.cli.etl_job", side_effect=SystemExit) as mock_etl_job, + mock.patch.object(loaders.FhirNdjsonLoader, "load_all", return_value=results), + ): + await self.run_etl(tasks=["observation"]) + + self.assertEqual(mock_etl_job.call_count, 1) + config = mock_etl_job.call_args[0][0] + self.assertEqual({"Observation": {"obs1"}}, config.deleted_ids) + class TestEtlJobConfig(BaseEtlSimple): """Test case for the job config logging data""" diff --git a/tests/etl/test_tasks.py b/tests/etl/test_tasks.py index 7bc406b..b3ac408 100644 --- a/tests/etl/test_tasks.py +++ b/tests/etl/test_tasks.py @@ -7,6 +7,7 @@ import ddt from cumulus_etl import common, errors +from cumulus_etl.etl import tasks from cumulus_etl.etl.tasks import basic_tasks, task_factory from tests.etl import TaskTestCase @@ -133,6 +134,46 @@ async def test_batch_is_given_schema(self): self.assertIn("address", schema.names) self.assertIn("id", schema.names) + async def test_get_schema(self): + """Verify that Task.get_schema() works for resources and non-resources""" + schema = tasks.EtlTask.get_schema("Patient", []) + self.assertIn("gender", schema.names) + schema = tasks.EtlTask.get_schema(None, []) + self.assertIsNone(schema) + + async def test_prepare_can_skip_task(self): + """Verify that if prepare_task returns false, we skip the task""" + self.make_json("Patient", "A") + with mock.patch( + "cumulus_etl.etl.tasks.basic_tasks.PatientTask.prepare_task", return_value=False + ): + summaries = await basic_tasks.PatientTask(self.job_config, self.scrubber).run() + self.assertEqual(len(summaries), 1) + self.assertEqual(summaries[0].attempt, 0) + self.assertIsNone(self.format) + + async def test_deleted_ids_no_op(self): + """Verify that we don't try to delete IDs if none are given""" + # Just a simple test to confirm we don't even ask the formatter to consider + # deleting any IDs if we weren't given any. + await basic_tasks.PatientTask(self.job_config, self.scrubber).run() + self.assertEqual(self.format.delete_records.call_count, 0) + + async def test_deleted_ids(self): + """Verify that we send deleted IDs down to the formatter""" + self.job_config.deleted_ids = {"Patient": {"p1", "p2"}} + await basic_tasks.PatientTask(self.job_config, self.scrubber).run() + + self.assertEqual(self.format.delete_records.call_count, 1) + ids = self.format.delete_records.call_args[0][0] + self.assertEqual( + ids, + { + self.codebook.db.resource_hash("p1"), + self.codebook.db.resource_hash("p2"), + }, + ) + @ddt.ddt class TestTaskCompletion(TaskTestCase): diff --git a/tests/formats/test_deltalake.py b/tests/formats/test_deltalake.py index e1cfb89..1404cfa 100644 --- a/tests/formats/test_deltalake.py +++ b/tests/formats/test_deltalake.py @@ -5,6 +5,7 @@ import os import shutil import tempfile +from unittest import mock import ddt import pyarrow @@ -395,3 +396,116 @@ def test_update_existing(self): self.store(self.df(a=1, b=2)) self.store(self.df(a=999, c=3), update_existing=False) self.assert_lake_equal(self.df(a=1, b=2, c=3)) + + def test_s3_options(self): + """Verify that we read in S3 options and set spark config appropriately""" + # Save global/class-wide spark object, to be restored. Then clear it out. + old_spark = DeltaLakeFormat.spark + + def restore_spark(): + DeltaLakeFormat.spark = old_spark + + self.addCleanup(restore_spark) + DeltaLakeFormat.spark = None + + # Now re-initialize the class, mocking out all the slow spark stuff, and using S3. + fs_options = { + "s3_kms_key": "test-key", + "s3_region": "us-west-1", + } + with ( + mock.patch("cumulus_etl.store._user_fs_options", new=fs_options), + mock.patch("delta.configure_spark_with_delta_pip"), + mock.patch("pyspark.sql"), + ): + DeltaLakeFormat.initialize_class(store.Root("s3://test/")) + + self.assertEqual( + sorted(DeltaLakeFormat.spark.conf.set.call_args_list, key=lambda x: x[0][0]), + [ + mock.call( + "fs.s3a.aws.credentials.provider", + "com.amazonaws.auth.DefaultAWSCredentialsProviderChain", + ), + mock.call("fs.s3a.endpoint.region", "us-west-1"), + mock.call("fs.s3a.server-side-encryption-algorithm", "SSE-KMS"), + mock.call("fs.s3a.server-side-encryption.key", "test-key"), + mock.call("fs.s3a.sse.enabled", "true"), + ], + ) + + def test_finalize_happy_path(self): + """Verify that we clean up the delta lake when finalizing.""" + # Limit our fake table to just these attributes, to notice any new usage in future + mock_table = mock.MagicMock(spec=["generate", "optimize", "vacuum"]) + self.patch("delta.DeltaTable.forPath", return_value=mock_table) + + DeltaLakeFormat(self.root, "patient").finalize() + self.assertEqual(mock_table.optimize.call_args_list, [mock.call()]) + self.assertEqual( + mock_table.optimize.return_value.executeCompaction.call_args_list, [mock.call()] + ) + self.assertEqual(mock_table.generate.call_args_list, [mock.call("symlink_format_manifest")]) + self.assertEqual(mock_table.vacuum.call_args_list, [mock.call()]) + + def test_finalize_cannot_load_table(self): + """Verify that we gracefully handle failing to read an existing table when finalizing.""" + # No table + deltalake = DeltaLakeFormat(self.root, "patient") + with self.assertNoLogs(): + deltalake.finalize() + self.assertFalse(os.path.exists(self.output_dir)) + + # Error loading the table + with self.assertLogs(level="ERROR") as logs: + with mock.patch("delta.DeltaTable.forPath", side_effect=ValueError): + deltalake.finalize() + self.assertEqual(len(logs.output), 1) + self.assertTrue( + logs.output[0].startswith("ERROR:root:Could not load Delta Lake table patient\n") + ) + + def test_finalize_error(self): + """Verify that we gracefully handle an error while finalizing.""" + self.store(self.df(a=1)) # create a simple table to load + with self.assertLogs(level="ERROR") as logs: + with mock.patch("delta.DeltaTable.optimize", side_effect=ValueError): + DeltaLakeFormat(self.root, "patient").finalize() + self.assertEqual(len(logs.output), 1) + self.assertTrue( + logs.output[0].startswith("ERROR:root:Could not finalize Delta Lake table patient\n") + ) + + def test_delete_records_happy_path(self): + """Verify that `delete_records` works in a basic way.""" + self.store(self.df(a=1, b=2, c=3, d=4)) + + deltalake = DeltaLakeFormat(self.root, "patient") + deltalake.delete_records({"a", "c"}) + deltalake.delete_records({"d"}) + deltalake.delete_records(set()) + + self.assert_lake_equal(self.df(b=2)) + + def test_delete_records_cannot_load_table(self): + """Verify we gracefully handle a missing table""" + deltalake = DeltaLakeFormat(self.root, "patient") + with self.assertNoLogs(): + deltalake.delete_records({"a"}) + self.assertFalse(os.path.exists(self.output_dir)) + + def test_delete_records_error(self): + """Verify that `delete_records` handles errors gracefully.""" + mock_table = mock.MagicMock(spec=["delete"]) + mock_table.delete.side_effect = ValueError + self.patch("delta.DeltaTable.forPath", return_value=mock_table) + + with self.assertLogs(level="ERROR") as logs: + DeltaLakeFormat(self.root, "patient").delete_records("a") + + self.assertEqual(len(logs.output), 1) + self.assertTrue( + logs.output[0].startswith( + "ERROR:root:Could not delete IDs from Delta Lake table patient\n" + ) + ) diff --git a/tests/formats/test_ndjson.py b/tests/formats/test_ndjson.py index 2c5e412..ebd65f1 100644 --- a/tests/formats/test_ndjson.py +++ b/tests/formats/test_ndjson.py @@ -56,3 +56,20 @@ def test_disallows_existing_files(self, files: None | list[str], is_ok: bool): else: with self.assertRaises(SystemExit): NdjsonFormat.initialize_class(self.root) + + def test_writes_deleted_ids(self): + """Verify that we write a table metadata file with deleted IDs""" + meta_path = f"{self.root.joinpath('condition')}/condition.meta" + + # Test with a fresh directory + formatter = NdjsonFormat(self.root, "condition") + formatter.delete_records({"b", "a"}) + metadata = common.read_json(meta_path) + self.assertEqual(metadata, {"deleted": ["a", "b"]}) + + # Confirm we append to existing metadata, should we ever need to + metadata["extra"] = "bonus metadata!" + common.write_json(meta_path, metadata) + formatter.delete_records({"c"}) + metadata = common.read_json(meta_path) + self.assertEqual(metadata, {"deleted": ["a", "b", "c"], "extra": "bonus metadata!"}) diff --git a/tests/loaders/i2b2/test_i2b2_loader.py b/tests/loaders/i2b2/test_i2b2_loader.py index bea7023..4e1c31f 100644 --- a/tests/loaders/i2b2/test_i2b2_loader.py +++ b/tests/loaders/i2b2/test_i2b2_loader.py @@ -4,7 +4,7 @@ import shutil import tempfile -from cumulus_etl import store +from cumulus_etl import common, store from cumulus_etl.loaders.i2b2 import loader from tests.utils import AsyncTestCase @@ -22,6 +22,22 @@ async def test_missing_files(self): vitals = f"{self.datadir}/i2b2/input/observation_fact_vitals.csv" shutil.copy(vitals, tmpdir) - loaded_dir = await i2b2_loader.load_all(["Observation", "Patient"]) + results = await i2b2_loader.load_all(["Observation", "Patient"]) - self.assertEqual(["Observation.1.ndjson"], os.listdir(loaded_dir.name)) + self.assertEqual(["Observation.1.ndjson"], os.listdir(results.path)) + + async def test_duplicate_ids(self): + """Verify that we ignore duplicate IDs""" + with tempfile.TemporaryDirectory() as tmpdir: + root = store.Root(tmpdir) + i2b2_loader = loader.I2b2Loader(root) + + common.write_text( + f"{tmpdir}/patient_dimension.csv", + "PATIENT_NUM,BIRTH_DATE\n" "123,1982-10-16\n" "123,1983-11-17\n" "456,2000-01-13\n", + ) + + results = await i2b2_loader.load_all(["Patient"]) + rows = common.read_resource_ndjson(store.Root(results.path), "Patient") + values = [(r["id"], r["birthDate"]) for r in rows] + self.assertEqual(values, [("123", "1982-10-16"), ("456", "2000-01-13")]) diff --git a/tests/loaders/i2b2/test_i2b2_oracle_extract.py b/tests/loaders/i2b2/test_i2b2_oracle_extract.py index 75f2d34..8c2071a 100644 --- a/tests/loaders/i2b2/test_i2b2_oracle_extract.py +++ b/tests/loaders/i2b2/test_i2b2_oracle_extract.py @@ -93,7 +93,7 @@ async def test_loader(self, mock_extract): root = store.Root("tcp://localhost/foo") oracle_loader = loader.I2b2Loader(root) - tmpdir = await oracle_loader.load_all(["Condition", "Encounter", "Patient"]) + results = await oracle_loader.load_all(["Condition", "Encounter", "Patient"]) # Check results self.assertEqual( @@ -102,17 +102,17 @@ async def test_loader(self, mock_extract): "Encounter.ndjson", "Patient.ndjson", }, - set(os.listdir(tmpdir.name)), + set(os.listdir(results.path)), ) self.assertEqual( i2b2_mock_data.condition(), - common.read_json(os.path.join(tmpdir.name, "Condition.ndjson")), + common.read_json(os.path.join(results.path, "Condition.ndjson")), ) self.assertEqual( i2b2_mock_data.encounter(), - common.read_json(os.path.join(tmpdir.name, "Encounter.ndjson")), + common.read_json(os.path.join(results.path, "Encounter.ndjson")), ) self.assertEqual( - i2b2_mock_data.patient(), common.read_json(os.path.join(tmpdir.name, "Patient.ndjson")) + i2b2_mock_data.patient(), common.read_json(os.path.join(results.path, "Patient.ndjson")) ) diff --git a/tests/loaders/ndjson/test_ndjson_loader.py b/tests/loaders/ndjson/test_ndjson_loader.py index a44b877..ea88e14 100644 --- a/tests/loaders/ndjson/test_ndjson_loader.py +++ b/tests/loaders/ndjson/test_ndjson_loader.py @@ -62,13 +62,13 @@ async def test_local_happy_path(self): writer.write(patient) loader = loaders.FhirNdjsonLoader(store.Root(tmpdir)) - loaded_dir = await loader.load_all(["Patient"]) + results = await loader.load_all(["Patient"]) - self.assertEqual(["Patient.ndjson"], os.listdir(loaded_dir.name)) - self.assertEqual(patient, common.read_json(f"{loaded_dir.name}/Patient.ndjson")) - self.assertEqual("G", loader.group_name) + self.assertEqual(["Patient.ndjson"], os.listdir(results.path)) + self.assertEqual(patient, common.read_json(f"{results.path}/Patient.ndjson")) + self.assertEqual("G", results.group_name) self.assertEqual( - datetime.datetime.fromisoformat("1999-03-14T14:12:10"), loader.export_datetime + datetime.datetime.fromisoformat("1999-03-14T14:12:10"), results.export_datetime ) # At some point, we do want to make this fatal. @@ -80,11 +80,11 @@ async def test_log_parsing_is_non_fatal(self): self._write_log_file(f"{tmpdir}/log.2.ndjson", "G2", "2002-02-02") loader = loaders.FhirNdjsonLoader(store.Root(tmpdir)) - await loader.load_all([]) + results = await loader.load_all([]) # We used neither log and didn't error out. - self.assertIsNone(loader.group_name) - self.assertIsNone(loader.export_datetime) + self.assertIsNone(results.group_name) + self.assertIsNone(results.export_datetime) @mock.patch("cumulus_etl.fhir.fhir_client.FhirClient") @mock.patch("cumulus_etl.etl.cli.loaders.FhirNdjsonLoader") @@ -299,7 +299,7 @@ async def fake_export() -> None: loader = loaders.FhirNdjsonLoader( store.Root("http://localhost:9999"), mock.AsyncMock(), export_to=target ) - folder = await loader.load_all(["Patient"]) + results = await loader.load_all(["Patient"]) # Confirm export folder still has the data (and log) we created above in the mock self.assertTrue(os.path.isdir(target)) @@ -309,9 +309,9 @@ async def fake_export() -> None: self.assertEqual({"eventId": "kickoff"}, common.read_json(f"{target}/log.ndjson")) # Confirm the returned dir has only the data (we don't want to confuse MS tool with logs) - self.assertNotEqual(folder.name, target) - self.assertEqual({"Patient.ndjson"}, set(os.listdir(folder.name))) - self.assertEqual(patient, common.read_json(f"{folder.name}/Patient.ndjson")) + self.assertNotEqual(results.path, target) + self.assertEqual({"Patient.ndjson"}, set(os.listdir(results.path))) + self.assertEqual(patient, common.read_json(f"{results.path}/Patient.ndjson")) async def test_export_internal_folder_happy_path(self): """Test that we can also safely export without an export-to folder involved""" @@ -325,11 +325,11 @@ async def fake_export() -> None: self.mock_exporter.export.side_effect = fake_export loader = loaders.FhirNdjsonLoader(store.Root("http://localhost:9999"), mock.AsyncMock()) - folder = await loader.load_all(["Patient"]) + results = await loader.load_all(["Patient"]) # Confirm the returned dir has only the data (we don't want to confuse MS tool with logs) - self.assertEqual({"Patient.ndjson"}, set(os.listdir(folder.name))) - self.assertEqual(patient, common.read_json(f"{folder.name}/Patient.ndjson")) + self.assertEqual({"Patient.ndjson"}, set(os.listdir(results.path))) + self.assertEqual(patient, common.read_json(f"{results.path}/Patient.ndjson")) async def test_export_to_folder_has_contents(self): """Verify we fail if an export folder already has contents""" @@ -350,3 +350,47 @@ async def test_export_to_folder_not_local(self): with self.assertRaises(SystemExit) as cm: await loader.load_all([]) self.assertEqual(cm.exception.code, errors.BULK_EXPORT_FOLDER_NOT_LOCAL) + + async def test_reads_deleted_ids(self): + """Verify we read in the deleted/ folder""" + with tempfile.TemporaryDirectory() as tmpdir: + os.mkdir(f"{tmpdir}/deleted") + common.write_json( + f"{tmpdir}/deleted/deletes.ndjson", + { + "resourceType": "Bundle", + "type": "transaction", + "entry": [ + {"request": {"method": "GET", "url": "Patient/bad-method"}}, + {"request": {"method": "DELETE", "url": "Patient/pat1"}}, + {"request": {"method": "DELETE", "url": "Patient/too/many/slashes"}}, + {"request": {"method": "DELETE", "url": "Condition/con1"}}, + {"request": {"method": "DELETE", "url": "Condition/con2"}}, + ], + }, + ) + # This next bundle will be ignored because of the wrong "type" + common.write_json( + f"{tmpdir}/deleted/messages.ndjson", + { + "resourceType": "Bundle", + "type": "message", + "entry": [ + { + "request": {"method": "DELETE", "url": "Patient/wrong-message-type"}, + } + ], + }, + ) + # This next file will be ignored because of the wrong "resourceType" + common.write_json( + f"{tmpdir}/deleted/conditions-for-some-reason.ndjson", + { + "resourceType": "Condition", + "recordedDate": "2024-09-04", + }, + ) + loader = loaders.FhirNdjsonLoader(store.Root(tmpdir)) + results = await loader.load_all(["Patient"]) + + self.assertEqual(results.deleted_ids, {"Patient": {"pat1"}, "Condition": {"con1", "con2"}})