Return SourceCatalog instead of SorchaOutputs
moeyensj committed Sep 19, 2024
1 parent 16409d3 commit 7baa528
Showing 2 changed files with 70 additions and 130 deletions.
src/adam_test_data/main.py: 121 changes (42 additions & 79 deletions)
@@ -429,7 +429,7 @@ def sorcha(
     randomization: bool = True,
     output_columns: Literal["basic", "all"] = "all",
     cleanup: bool = True,
-) -> tuple[Union[SorchaOutputBasic, SorchaOutputAll], SorchaOutputStats]:
+) -> SourceCatalog:
     """
     Run sorcha on the given small bodies, pointings, and observatory.
@@ -458,9 +458,8 @@
     Returns
     -------
-    tuple[Union[SorchaOutputBasic, SorchaOutputAll], SorchaOutputStats]
-        Sorcha output observations (in basic or all formats) and statistics
-        per object and filter.
+    SourceCatalog
+        Sorcha outputs as an adam_core SourceCatalog.
     """
     # Create the output directory if it doesn't exist
     os.makedirs(output_dir, exist_ok=True)
@@ -506,7 +505,6 @@
         )
     else:
         output_file = f"{output_dir}/{tag}.csv"
-    output_stats_file = f"{output_dir}/{stats_file}.csv"

     sorcha_output_table: Union[Type[SorchaOutputBasic], Type[SorchaOutputAll]]
     if output_columns == "basic":
@@ -520,15 +518,18 @@
     sorcha_outputs: Union[SorchaOutputBasic, SorchaOutputAll]
     if os.path.exists(output_file):
         sorcha_outputs = sorcha_output_table.from_csv(output_file)
-        sorcha_stats = SorchaOutputStats.from_csv(output_stats_file)
+        source_catalog = sorcha_outputs.to_source_catalog(
+            catalog_id=tag, observatory_code=observatory.code
+        )
+
     else:
         sorcha_outputs = sorcha_output_table.empty()
-        sorcha_stats = SorchaOutputStats.empty()
+        source_catalog = SourceCatalog.empty()

     if cleanup:
         shutil.rmtree(output_dir)

-    return sorcha_outputs, sorcha_stats
+    return source_catalog


 def sorcha_worker(
@@ -544,7 +545,7 @@ def sorcha_worker(
     randomization: bool = True,
     output_columns: Literal["basic", "all"] = "all",
     cleanup: bool = True,
-) -> tuple[str, str]:
+) -> str:
     """
     Run sorcha on a subset of the input small bodies.
@@ -575,8 +576,8 @@
     Returns
     -------
-    tuple[str, str]
-        The paths to the Sorcha output files.
+    str
+        The path to the SourceCatalog output file.
     """
     orbit_ids_chunk = orbit_ids[orbit_ids_indices[0] : orbit_ids_indices[1]]

@@ -588,7 +589,7 @@
     chunk_base = f"chunk_{orbit_ids_indices[0]:08d}_{orbit_ids_indices[1]:08d}"
     output_dir_chunk = os.path.join(output_dir, chunk_base)

-    sorcha_outputs, sorcha_stats = sorcha(
+    catalog = sorcha(
         output_dir_chunk,
         small_bodies_chunk,
         pointings,
@@ -602,13 +603,10 @@
     )

     # Serialize the output tables to parquet and return the paths
-    sorcha_output_file = os.path.join(output_dir, f"{chunk_base}_{tag}.parquet")
-    sorcha_stats_file = os.path.join(output_dir, f"{chunk_base}_{tag}_stats.parquet")
+    catalog_file = os.path.join(output_dir, f"{chunk_base}_{tag}.parquet")
+    catalog.to_parquet(catalog_file)

-    sorcha_outputs.to_parquet(sorcha_output_file)
-    sorcha_stats.to_parquet(sorcha_stats_file)
-
-    return sorcha_output_file, sorcha_stats_file
+    return catalog_file


 sorcha_worker_remote = ray.remote(sorcha_worker)
@@ -628,7 +626,7 @@ def run_sorcha(
     chunk_size: int = 1000,
     max_processes: Optional[int] = 1,
     cleanup: bool = True,
-) -> tuple[str, str]:
+) -> str:
     """
     Run sorcha on the given small bodies, pointings, and observatory.
Returns
-------
tuple[Union[SorchaOutputBasic, SorchaOutputAll], SorchaOutputStats]
Sorcha output observations (in basic or all formats) and statistics
per object and filter.
str
The path to the SourceCatalog output file.
"""
if max_processes is None:
max_processes = mp.cpu_count()

orbit_ids = small_bodies.orbits.orbit_id

sorcha_outputs: Union[SorchaOutputBasic, SorchaOutputAll]
if output_columns == "basic":
sorcha_outputs = SorchaOutputBasic.empty()
else:
sorcha_outputs = SorchaOutputAll.empty()
sorcha_stats = SorchaOutputStats.empty()

# Create the output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

sorcha_output_file = os.path.join(output_dir, f"{tag}.parquet")
sorcha_stats_file = os.path.join(output_dir, f"{tag}_stats.parquet")

# Write the empty tables to parquet
sorcha_outputs.to_parquet(sorcha_output_file)
sorcha_stats.to_parquet(sorcha_stats_file)
catalog = SourceCatalog.empty()
catalog_file = os.path.join(output_dir, f"{tag}.parquet")
catalog.to_parquet(catalog_file)

sorcha_outputs_writer = pq.ParquetWriter(
sorcha_output_file,
sorcha_outputs.schema,
)
sorcha_stats_writer = pq.ParquetWriter(
sorcha_stats_file,
sorcha_stats.schema,
# Create a Parquet writer for the output catalog
catalog_writer = pq.ParquetWriter(
catalog_file,
catalog.schema,
)

use_ray = initialize_use_ray(num_cpus=max_processes)
@@ -728,42 +712,28 @@

             if len(futures) >= max_processes * 1.5:
                 finished, futures = ray.wait(futures, num_returns=1)
-                sorcha_outputs_chunk_file, sorcha_stats_chunk_file = ray.get(
-                    finished[0]
-                )
-
-                sorcha_outputs_chunk = sorcha_outputs.from_parquet(
-                    sorcha_outputs_chunk_file
-                )
-                sorcha_stats_chunk = sorcha_stats.from_parquet(sorcha_stats_chunk_file)
+                catalog_chunk_file = ray.get(finished[0])

-                sorcha_outputs_writer.write_table(sorcha_outputs_chunk.table)
-                sorcha_stats_writer.write_table(sorcha_stats_chunk.table)
+                catalog_chunk = SourceCatalog.from_parquet(catalog_chunk_file)
+                catalog_writer.write_table(catalog_chunk.table)

                 if cleanup:
-                    os.remove(sorcha_outputs_chunk_file)
-                    os.remove(sorcha_stats_chunk_file)
+                    os.remove(catalog_chunk_file)

         while futures:
             finished, futures = ray.wait(futures, num_returns=1)
-            sorcha_outputs_chunk_file, sorcha_stats_chunk_file = ray.get(finished[0])
+            catalog_chunk_file = ray.get(finished[0])

-            sorcha_outputs_chunk = sorcha_outputs.from_parquet(
-                sorcha_outputs_chunk_file
-            )
-            sorcha_stats_chunk = sorcha_stats.from_parquet(sorcha_stats_chunk_file)
-
-            sorcha_outputs_writer.write_table(sorcha_outputs_chunk.table)
-            sorcha_stats_writer.write_table(sorcha_stats_chunk.table)
+            catalog_chunk = SourceCatalog.from_parquet(catalog_chunk_file)
+            catalog_writer.write_table(catalog_chunk.table)

             if cleanup:
-                os.remove(sorcha_outputs_chunk_file)
-                os.remove(sorcha_stats_chunk_file)
+                os.remove(catalog_chunk_file)

     else:

         for orbit_ids_indices in _iterate_chunk_indices(orbit_ids, chunk_size):
-            sorcha_outputs_chunk_file, sorcha_stats_chunk_file = sorcha_worker(
+            catalog_chunk_file = sorcha_worker(
                 orbit_ids,
                 orbit_ids_indices,
                 output_dir,
@@ -778,19 +748,13 @@
                 cleanup=cleanup,
             )

-            sorcha_outputs_chunk = sorcha_outputs.from_parquet(
-                sorcha_outputs_chunk_file
-            )
-            sorcha_stats_chunk = sorcha_stats.from_parquet(sorcha_stats_chunk_file)
-
-            sorcha_outputs_writer.write_table(sorcha_outputs_chunk.table)
-            sorcha_stats_writer.write_table(sorcha_stats_chunk.table)
+            catalog_chunk = SourceCatalog.from_parquet(catalog_chunk_file)
+            catalog_writer.write_table(catalog_chunk.table)

             if cleanup:
-                os.remove(sorcha_outputs_chunk_file)
-                os.remove(sorcha_stats_chunk_file)
+                os.remove(catalog_chunk_file)

-    return sorcha_output_file, sorcha_stats_file
+    return catalog_file


 def generate_test_data(
Expand Down Expand Up @@ -865,7 +829,7 @@ def generate_test_data(
os.makedirs(output_dir, exist_ok=True)

# Run sorcha
sorcha_outputs_file, sorcha_stats_file = run_sorcha(
catalog_file = run_sorcha(
output_dir,
small_bodies,
pointings_filtered,
@@ -883,8 +847,7 @@
     noise_files: dict[str, str] = {}
     if noise_densities is None:
         return (
-            sorcha_outputs_file,
-            sorcha_stats_file,
+            catalog_file,
            noise_files,
        )

@@ -904,4 +867,4 @@
         )
         noise_files[f"{noise_density:.2f}"] = sorcha_outputs_noisy

-    return sorcha_outputs_file, sorcha_stats_file, noise_files
+    return catalog_file, noise_files
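
Note: the run_sorcha changes keep the same streaming-write pattern as before, just with a single output table: one pq.ParquetWriter is opened against the final output path, and each worker chunk's table is appended as it lands. Below is a minimal standalone sketch of that pattern; the schema, file name, and fabricated chunk tables are illustrative stand-ins (in the real code the schema comes from SourceCatalog.schema and the chunks from sorcha_worker), not adam_test_data or adam_core code.

import pyarrow as pa
import pyarrow.parquet as pq

# Illustrative schema; run_sorcha uses catalog.schema instead.
schema = pa.schema([("orbit_id", pa.string()), ("mag", pa.float64())])

# One writer, opened once with a fixed schema, like catalog_writer above.
writer = pq.ParquetWriter("catalog.parquet", schema)

for start in range(0, 3000, 1000):
    # In run_sorcha each chunk comes back from sorcha_worker (directly or
    # via a Ray future); here we fabricate a tiny stand-in table.
    chunk = pa.table(
        {
            "orbit_id": [f"obj_{start + i}" for i in range(3)],
            "mag": [20.0, 21.5, 19.8],
        },
        schema=schema,
    )
    writer.write_table(chunk)  # appends one row group per chunk

writer.close()  # finalizes the parquet footer

One detail worth noting: closing the writer matters even when no chunks arrive, since the close is what writes a valid parquet footer for an empty file.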