diff --git a/src/adam_test_data/main.py b/src/adam_test_data/main.py
index bf0151f..88aeb51 100644
--- a/src/adam_test_data/main.py
+++ b/src/adam_test_data/main.py
@@ -429,7 +429,7 @@ def sorcha(
     randomization: bool = True,
     output_columns: Literal["basic", "all"] = "all",
     cleanup: bool = True,
-) -> tuple[Union[SorchaOutputBasic, SorchaOutputAll], SorchaOutputStats]:
+) -> SourceCatalog:
     """
     Run sorcha on the given small bodies, pointings, and observatory.
 
@@ -458,9 +458,8 @@
 
     Returns
     -------
-    tuple[Union[SorchaOutputBasic, SorchaOutputAll], SorchaOutputStats]
-        Sorcha output observations (in basic or all formats) and statistics
-        per object and filter.
+    SourceCatalog
+        Sorcha outputs as an adam_core SourceCatalog.
     """
     # Create the output directory if it doesn't exist
     os.makedirs(output_dir, exist_ok=True)
@@ -506,7 +505,6 @@
         )
     else:
         output_file = f"{output_dir}/{tag}.csv"
-        output_stats_file = f"{output_dir}/{stats_file}.csv"
 
     sorcha_output_table: Union[Type[SorchaOutputBasic], Type[SorchaOutputAll]]
     if output_columns == "basic":
@@ -520,15 +518,18 @@
     sorcha_outputs: Union[SorchaOutputBasic, SorchaOutputAll]
     if os.path.exists(output_file):
         sorcha_outputs = sorcha_output_table.from_csv(output_file)
-        sorcha_stats = SorchaOutputStats.from_csv(output_stats_file)
+        source_catalog = sorcha_outputs.to_source_catalog(
+            catalog_id=tag, observatory_code=observatory.code
+        )
+
     else:
         sorcha_outputs = sorcha_output_table.empty()
-        sorcha_stats = SorchaOutputStats.empty()
+        source_catalog = SourceCatalog.empty()
 
     if cleanup:
         shutil.rmtree(output_dir)
 
-    return sorcha_outputs, sorcha_stats
+    return source_catalog
 
 
 def sorcha_worker(
@@ -544,7 +545,7 @@
     randomization: bool = True,
     output_columns: Literal["basic", "all"] = "all",
     cleanup: bool = True,
-) -> tuple[str, str]:
+) -> str:
     """
     Run sorcha on a subset of the input small bodies.
 
@@ -575,8 +576,8 @@
 
     Returns
    -------
-    tuple[str, str]
-        The paths to the Sorcha output files.
+    str
+        The path to the SourceCatalog output file.
     """
 
     orbit_ids_chunk = orbit_ids[orbit_ids_indices[0] : orbit_ids_indices[1]]
@@ -588,7 +589,7 @@
     chunk_base = f"chunk_{orbit_ids_indices[0]:08d}_{orbit_ids_indices[1]:08d}"
     output_dir_chunk = os.path.join(output_dir, chunk_base)
 
-    sorcha_outputs, sorcha_stats = sorcha(
+    catalog = sorcha(
         output_dir_chunk,
         small_bodies_chunk,
         pointings,
@@ -602,13 +603,10 @@
         )
 
     # Serialize the output tables to parquet and return the paths
-    sorcha_output_file = os.path.join(output_dir, f"{chunk_base}_{tag}.parquet")
-    sorcha_stats_file = os.path.join(output_dir, f"{chunk_base}_{tag}_stats.parquet")
+    catalog_file = os.path.join(output_dir, f"{chunk_base}_{tag}.parquet")
+    catalog.to_parquet(catalog_file)
 
-    sorcha_outputs.to_parquet(sorcha_output_file)
-    sorcha_stats.to_parquet(sorcha_stats_file)
-
-    return sorcha_output_file, sorcha_stats_file
+    return catalog_file
 
 
 sorcha_worker_remote = ray.remote(sorcha_worker)
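For reference, the call-site contract these hunks establish: sorcha() now hands back a single SourceCatalog, and sorcha_worker() a single parquet path, instead of parallel observation/stats values. A minimal sketch, assuming fixtures like the ones in test_main.py further down; the position of the observatory argument is an assumption, not the verified signature:

# Sketch only, not part of the patch. SourceCatalog is assumed to be a
# quivr-style table: len(), .exposure_id, .to_parquet()/.from_parquet().
from adam_core.observations import SourceCatalog

from adam_test_data.main import sorcha


def make_catalog(out_dir, small_bodies, pointings, observatory) -> SourceCatalog:
    catalog = sorcha(
        out_dir,               # scratch directory for sorcha's CSV output
        small_bodies,          # fixtures as in test_main.py below
        pointings,
        observatory,           # assumed argument position
        randomization=False,
        output_columns="all",
        cleanup=True,          # the scratch directory is removed afterwards
    )
    # One value to carry around instead of an (outputs, stats) tuple:
    catalog.to_parquet(f"{out_dir}.parquet")
    return catalog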
""" if max_processes is None: max_processes = mp.cpu_count() orbit_ids = small_bodies.orbits.orbit_id - sorcha_outputs: Union[SorchaOutputBasic, SorchaOutputAll] - if output_columns == "basic": - sorcha_outputs = SorchaOutputBasic.empty() - else: - sorcha_outputs = SorchaOutputAll.empty() - sorcha_stats = SorchaOutputStats.empty() - # Create the output directory if it doesn't exist os.makedirs(output_dir, exist_ok=True) - sorcha_output_file = os.path.join(output_dir, f"{tag}.parquet") - sorcha_stats_file = os.path.join(output_dir, f"{tag}_stats.parquet") - - # Write the empty tables to parquet - sorcha_outputs.to_parquet(sorcha_output_file) - sorcha_stats.to_parquet(sorcha_stats_file) + catalog = SourceCatalog.empty() + catalog_file = os.path.join(output_dir, f"{tag}.parquet") + catalog.to_parquet(catalog_file) - sorcha_outputs_writer = pq.ParquetWriter( - sorcha_output_file, - sorcha_outputs.schema, - ) - sorcha_stats_writer = pq.ParquetWriter( - sorcha_stats_file, - sorcha_stats.schema, + # Create a Parquet writer for the output catalog + catalog_writer = pq.ParquetWriter( + catalog_file, + catalog.schema, ) use_ray = initialize_use_ray(num_cpus=max_processes) @@ -728,42 +712,28 @@ def run_sorcha( if len(futures) >= max_processes * 1.5: finished, futures = ray.wait(futures, num_returns=1) - sorcha_outputs_chunk_file, sorcha_stats_chunk_file = ray.get( - finished[0] - ) - - sorcha_outputs_chunk = sorcha_outputs.from_parquet( - sorcha_outputs_chunk_file - ) - sorcha_stats_chunk = sorcha_stats.from_parquet(sorcha_stats_chunk_file) + catalog_chunk_file = ray.get(finished[0]) - sorcha_outputs_writer.write_table(sorcha_outputs_chunk.table) - sorcha_stats_writer.write_table(sorcha_stats_chunk.table) + catalog_chunk = SourceCatalog.from_parquet(catalog_chunk_file) + catalog_writer.write_table(catalog_chunk.table) if cleanup: - os.remove(sorcha_outputs_chunk_file) - os.remove(sorcha_stats_chunk_file) + os.remove(catalog_chunk_file) while futures: finished, futures = ray.wait(futures, num_returns=1) - sorcha_outputs_chunk_file, sorcha_stats_chunk_file = ray.get(finished[0]) + catalog_chunk_file = ray.get(finished[0]) - sorcha_outputs_chunk = sorcha_outputs.from_parquet( - sorcha_outputs_chunk_file - ) - sorcha_stats_chunk = sorcha_stats.from_parquet(sorcha_stats_chunk_file) - - sorcha_outputs_writer.write_table(sorcha_outputs_chunk.table) - sorcha_stats_writer.write_table(sorcha_stats_chunk.table) + catalog_chunk = SourceCatalog.from_parquet(catalog_chunk_file) + catalog_writer.write_table(catalog_chunk.table) if cleanup: - os.remove(sorcha_outputs_chunk_file) - os.remove(sorcha_stats_chunk_file) + os.remove(catalog_chunk_file) else: for orbit_ids_indices in _iterate_chunk_indices(orbit_ids, chunk_size): - sorcha_outputs_chunk_file, sorcha_stats_chunk_file = sorcha_worker( + catalog_chunk_file = sorcha_worker( orbit_ids, orbit_ids_indices, output_dir, @@ -778,19 +748,13 @@ def run_sorcha( cleanup=cleanup, ) - sorcha_outputs_chunk = sorcha_outputs.from_parquet( - sorcha_outputs_chunk_file - ) - sorcha_stats_chunk = sorcha_stats.from_parquet(sorcha_stats_chunk_file) - - sorcha_outputs_writer.write_table(sorcha_outputs_chunk.table) - sorcha_stats_writer.write_table(sorcha_stats_chunk.table) + catalog_chunk = SourceCatalog.from_parquet(catalog_chunk_file) + catalog_writer.write_table(catalog_chunk.table) if cleanup: - os.remove(sorcha_outputs_chunk_file) - os.remove(sorcha_stats_chunk_file) + os.remove(catalog_chunk_file) - return sorcha_output_file, sorcha_stats_file + return catalog_file def 
@@ -865,7 +829,7 @@ def generate_test_data(
     os.makedirs(output_dir, exist_ok=True)
 
     # Run sorcha
-    sorcha_outputs_file, sorcha_stats_file = run_sorcha(
+    catalog_file = run_sorcha(
         output_dir,
         small_bodies,
         pointings_filtered,
@@ -883,8 +847,7 @@
     noise_files: dict[str, str] = {}
     if noise_densities is None:
         return (
-            sorcha_outputs_file,
-            sorcha_stats_file,
+            catalog_file,
             noise_files,
         )
 
@@ -904,4 +867,4 @@
         )
         noise_files[f"{noise_density:.2f}"] = sorcha_outputs_noisy
 
-    return sorcha_outputs_file, sorcha_stats_file, noise_files
+    return catalog_file, noise_files
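That completes the main.py side: generate_test_data() drops from a 3-tuple to a 2-tuple of the catalog path plus the per-density noise-file map. A sketch of unpacking the new return value, again assuming fixtures like those in the tests below and an assumed observatory argument position:

# Sketch only, not part of the patch.
from adam_core.observations import SourceCatalog

from adam_test_data.main import generate_test_data
from adam_test_data.noise import NoiseCatalog


def summarize(out_dir, small_bodies, pointings, observatory):
    catalog_file, noise_files = generate_test_data(
        out_dir,
        small_bodies,
        pointings,
        observatory,                      # assumed argument position
        noise_densities=[100.0, 1000.0],  # omit to get an empty noise_files
        cleanup=True,
    )
    catalog = SourceCatalog.from_parquet(catalog_file)
    # Keys are densities formatted to two decimals, e.g. "100.00" -> path
    for density, path in noise_files.items():
        print(density, len(NoiseCatalog.from_parquet(path)))
    return catalog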
diff --git a/src/adam_test_data/tests/test_main.py b/src/adam_test_data/tests/test_main.py
index 23f21ca..eb90d3b 100644
--- a/src/adam_test_data/tests/test_main.py
+++ b/src/adam_test_data/tests/test_main.py
@@ -1,6 +1,5 @@
 import os
 import tempfile
-from typing import Union
 
 import numpy as np
 import pyarrow as pa
@@ -8,17 +7,11 @@ import pytest
 
 from adam_core.coordinates import CartesianCoordinates
 from adam_core.coordinates.origin import Origin
+from adam_core.observations import SourceCatalog
 from adam_core.orbits import Orbits
 from adam_core.time import Timestamp
 
-from ..main import (
-    SorchaOutputAll,
-    SorchaOutputBasic,
-    SorchaOutputStats,
-    generate_test_data,
-    sorcha,
-    write_sorcha_inputs,
-)
+from ..main import generate_test_data, sorcha, write_sorcha_inputs
 from ..noise import NoiseCatalog
 from ..observatory import FieldOfView, Observatory, Simulation
 from ..pointings import Pointings
@@ -190,8 +183,7 @@
 
     # Test that _run_sorcha runs without error and returns least 6 observations
     with tempfile.TemporaryDirectory() as out_dir:
-        sorcha_outputs: Union[SorchaOutputAll, SorchaOutputBasic]
-        sorcha_outputs, sorcha_stats = sorcha(
+        catalog = sorcha(
             out_dir,
             small_bodies,
             pointings,
@@ -199,16 +191,13 @@
             randomization=False,
             output_columns="all",
         )
-        assert len(sorcha_outputs) == 6
-        if isinstance(sorcha_outputs, SorchaOutputAll):
-            assert pc.all(
-                pc.equal(
-                    sorcha_outputs.FieldID,
-                    pa.array(["exp00", "exp01", "exp02", "exp03", "exp04", "exp05"]),
-                )
+        assert len(catalog) == 6
+        assert pc.all(
+            pc.equal(
+                catalog.exposure_id,
+                pa.array(["exp00", "exp01", "exp02", "exp03", "exp04", "exp05"]),
             )
-
-        assert len(sorcha_stats) == 6  # One row for each object and filter
+        )
 
 
 def test_generate_test_data_no_noise(
@@ -217,7 +206,7 @@
 
     with tempfile.TemporaryDirectory() as out_dir:
 
-        results = generate_test_data(
+        catalog_file, noise_files = generate_test_data(
             out_dir,
             small_bodies,
             pointings,
@@ -229,17 +218,13 @@
             cleanup=True,
         )
 
-        assert len(results) == 3
-
-        sorcha_outputs = SorchaOutputAll.from_parquet(results[0])
-        sorcha_stats = SorchaOutputStats.from_parquet(results[1])
-        assert len(sorcha_outputs) == 6
-        assert len(sorcha_stats) == 6  # One row for each object and filter
-        assert len(os.listdir(out_dir)) == 2  # There should be only two files
+        catalog = SourceCatalog.from_parquet(catalog_file)
+        assert len(catalog) == 6
+        assert len(os.listdir(out_dir)) == 1  # There should be only one file
 
     with tempfile.TemporaryDirectory() as out_dir:
 
-        results = generate_test_data(
+        catalog_file, noise_files = generate_test_data(
             out_dir,
             small_bodies,
             pointings,
@@ -252,14 +237,10 @@
             cleanup=False,
        )
 
-        assert len(results) == 3
-
-        sorcha_outputs = SorchaOutputAll.from_parquet(results[0])
-        sorcha_stats = SorchaOutputStats.from_parquet(results[1])
-        assert len(sorcha_outputs) == 6
-        assert len(sorcha_stats) == 6  # One row for each object and filter
-        assert len(results[2]) == 0
-        assert len(os.listdir(out_dir)) == 5  # If we are not cleaning up
+        catalog = SourceCatalog.from_parquet(catalog_file)
+        assert len(catalog) == 6
+        assert len(noise_files) == 0
+        assert len(os.listdir(out_dir)) == 3  # If we are not cleaning up
         # then we expect there to be more files including the chunked partition
         # files and also a directory for those chunks
 
@@ -270,7 +251,7 @@ def test_generate_test_data_with_noise(
 
     with tempfile.TemporaryDirectory() as out_dir:
 
-        results = generate_test_data(
+        catalog_file, noise_files = generate_test_data(
             out_dir,
             small_bodies,
             pointings,
@@ -284,24 +265,20 @@
             cleanup=True,
         )
 
-        assert len(results) == 3
-
-        sorcha_outputs = SorchaOutputAll.from_parquet(results[0])
-        sorcha_stats = SorchaOutputStats.from_parquet(results[1])
-        assert len(sorcha_outputs) == 6
-        assert len(sorcha_stats) == 6  # One row for each object and filter
-        assert len(results[2]) == 2
-        assert "100.00" in results[2]
-        assert "1000.00" in results[2]
-        assert len(os.listdir(out_dir)) == 4  # There should be only 4 files
-        # 2 parquet files for sorcha outputs and 2 for noise outputs.
+        catalog = SourceCatalog.from_parquet(catalog_file)
+        assert len(catalog) == 6
+        assert len(noise_files) == 2
+        assert "100.00" in noise_files
+        assert "1000.00" in noise_files
+        assert len(os.listdir(out_dir)) == 3  # There should be only 3 files
+        # 1 parquet file for sorcha outputs and 2 for noise outputs.
 
-        noise100 = NoiseCatalog.from_parquet(results[2]["100.00"])
+        noise100 = NoiseCatalog.from_parquet(noise_files["100.00"])
         expected_noise_detections = 6 * 100 * 1.75**2 * np.pi
         assert len(noise100) >= 0.9 * expected_noise_detections
         assert len(noise100) <= 1.1 * expected_noise_detections
 
-        noise1000 = NoiseCatalog.from_parquet(results[2]["1000.00"])
+        noise1000 = NoiseCatalog.from_parquet(noise_files["1000.00"])
         expected_noise_detections = 6 * 1000 * 1.75**2 * np.pi
         assert len(noise1000) >= 0.9 * expected_noise_detections
         assert len(noise1000) <= 1.1 * expected_noise_detections
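A note on the noise bounds asserted at the end: 6 * density * 1.75**2 * np.pi reads as six exposures times a uniform per-square-degree density times the area of a circular field of view, taking 1.75 degrees as its radius (an inference from the formula, not something the diff states). Worked out, with the +/-10% acceptance window the tests use:

# Sketch only: the arithmetic behind expected_noise_detections.
import numpy as np

fov_area = np.pi * 1.75**2  # ~9.62 deg^2, assuming a 1.75 deg FOV radius
for density in (100, 1000):
    expected = 6 * density * fov_area
    print(f"{density}/deg^2 -> ~{expected:.0f} detections "
          f"(accept {0.9 * expected:.0f}..{1.1 * expected:.0f})")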