Skip to content

Commit

Permalink
Add TestDataSummary.load_test_data convenience function
Browse files Browse the repository at this point in the history
  • Loading branch information
moeyensj committed Sep 20, 2024
1 parent cb9bb31 commit 4e707b6
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,8 @@ catalog_file, noise_files, summary = generate_test_data(
chunk_size=500,
cleanup=True
)

# Load in the generated test data with the desired noise density
catalog = summary.load_test_data("NSC_W84_P9", 100)

```
51 changes: 51 additions & 0 deletions src/adam_test_data/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,57 @@ class TestDataSummary(qv.Table):
catalog_file = qv.StringColumn()
noise_file = qv.StringColumn(nullable=True)

def load_test_data(
    self, catalog_id: str, noise_density: Optional[float] = None
) -> SourceCatalog:
    """
    Load the source catalog for ``catalog_id``, optionally merged with noise.

    The rows of this summary table are filtered by ``catalog_id`` and the
    catalog parquet file referenced by the first matching row is read. When
    ``noise_density`` is given, the matching rows are filtered again by that
    density, the noise parquet file of the first remaining row is read, and
    the two catalogs are concatenated.

    Parameters
    ----------
    catalog_id : str
        The ID of the catalog.
    noise_density : float, optional
        The density of the noise observations to load alongside the
        catalog. If None (the default), only the catalog is returned.

    Returns
    -------
    catalog : SourceCatalog
        The test data catalog (with noise appended if noise_density is not None).

    Raises
    ------
    ValueError
        If no row matches ``catalog_id``, or if ``noise_density`` is given
        and no row for that catalog matches it. The message lists the
        unique values that are available.
    """
    matching_rows = self.select("catalog_id", catalog_id)
    if len(matching_rows) == 0:
        known_ids = self.catalog_id.unique().to_pylist()
        raise ValueError(
            f"No catalog found for catalog_id={catalog_id}\n"
            "Options are:\n"
            f"{known_ids}"
        )

    # The first matching row supplies the catalog file path.
    catalog_path = matching_rows.catalog_file[0].as_py()
    catalog = SourceCatalog.from_parquet(catalog_path)

    # Without a requested noise density there is nothing more to load.
    if noise_density is None:
        return catalog

    noise_rows = matching_rows.select("noise_density", noise_density)
    if len(noise_rows) == 0:
        known_densities = matching_rows.noise_density.unique().to_pylist()
        raise ValueError(
            f"No noise catalog found for catalog_id={catalog_id}"
            f" and noise_density={noise_density}\n"
            "Options are:\n"
            f"{known_densities}"
        )

    noise_path = noise_rows.noise_file[0].as_py()
    noise_catalog = SourceCatalog.from_parquet(noise_path)
    return qv.concatenate([catalog, noise_catalog])


def remove_quotes(file_path: str) -> None:
"""
Expand Down

0 comments on commit 4e707b6

Please sign in to comment.