diff --git a/.github/workflows/build_docs.yaml b/.github/workflows/build_docs.yaml index cf983fa1..c81a7b15 100644 --- a/.github/workflows/build_docs.yaml +++ b/.github/workflows/build_docs.yaml @@ -42,7 +42,7 @@ jobs: shell: bash -l {0} run: | export RUBIN_SIM_DATA_DIR=${{ github.workspace }} - rs_download_data --tdqm_disable -d 'scheduler' + scheduler_download_data --tdqm_disable -d 'scheduler' - name: check conda and documenteer shell: bash -l {0} diff --git a/.github/workflows/build_pypi.yaml b/.github/workflows/build_pypi.yaml index f63d57fd..e32a5c10 100644 --- a/.github/workflows/build_pypi.yaml +++ b/.github/workflows/build_pypi.yaml @@ -46,8 +46,8 @@ jobs: shell: bash -l {0} run: | export RUBIN_SIM_DATA_DIR=${{ github.workspace }}/data_dir - rs_download_data --tdqm_disable -d 'site_models,skybrightness_pre,scheduler' - rs_download_data --tdqm_disable -d tests --force + scheduler_download_data --tdqm_disable -d 'site_models,skybrightness_pre,scheduler' + scheduler_download_data --tdqm_disable -d tests --force - name: conda list shell: bash -l {0} diff --git a/.github/workflows/run_all_tests.yaml b/.github/workflows/run_all_tests.yaml index 223705ce..72d4739f 100644 --- a/.github/workflows/run_all_tests.yaml +++ b/.github/workflows/run_all_tests.yaml @@ -47,7 +47,7 @@ jobs: shell: bash -l {0} run: | export RUBIN_SIM_DATA_DIR=${{ github.workspace }}/data_dir - rs_download_data --force --tdqm_disable + scheduler_download_data --force --tdqm_disable - name: conda list shell: bash -l {0} diff --git a/.github/workflows/run_tests_docs.yaml b/.github/workflows/run_tests_docs.yaml index cb14165a..1414e528 100644 --- a/.github/workflows/run_tests_docs.yaml +++ b/.github/workflows/run_tests_docs.yaml @@ -49,8 +49,8 @@ jobs: shell: bash -l {0} run: | export RUBIN_SIM_DATA_DIR=${{ github.workspace }}/data_dir - rs_download_data --tdqm_disable -d 'site_models,scheduler,skybrightness_pre' - rs_download_data --tdqm_disable -d tests --force + scheduler_download_data 
--tdqm_disable -d 'site_models,scheduler,skybrightness_pre' + scheduler_download_data --tdqm_disable -d tests --force - name: conda list shell: bash -l {0} @@ -94,7 +94,7 @@ jobs: shell: bash -l {0} run: | export RUBIN_SIM_DATA_DIR=${{ github.workspace }} - rs_download_data --tdqm_disable -d 'scheduler' + scheduler_download_data --tdqm_disable -d 'scheduler' - name: check conda and documenteer shell: bash -l {0} diff --git a/.gitignore b/.gitignore index 68bc17f9..e09d322f 100644 --- a/.gitignore +++ b/.gitignore @@ -146,6 +146,15 @@ dmypy.json # Pyre type checker .pyre/ +# version file +version.py + +# pycharm files +.idea/ + +# vscode files +.vscode/ + # pytype static type analyzer .pytype/ diff --git a/pyproject.toml b/pyproject.toml index b672b140..eeb91eee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Topic :: Scientific/Engineering :: Astronomy", ] -urls = {documentation = "https://rubin-sim.lsst.io", repository = "https://github.com/lsst/rubin_sim" } +urls = {documentation = "https://rubin-scheduler.lsst.io", repository = "https://github.com/lsst/rubin_scheduler" } dynamic = [ "version" ] dependencies = [ "numpy", @@ -52,7 +52,7 @@ dev = [ ] [project.scripts] -rs_download_data = "rubin_scheduler.data.rs_download_data:rs_download_data" +scheduler_download_data = "rubin_scheduler.data.scheduler_download_data:scheduler_download_data" rs_download_sky = "rubin_scheduler.data.rs_download_sky:rs_download_sky" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..3fba1667 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,22 @@ +setuptools_scm +setuptools_scm_git_archive +numpy +matplotlib +healpy +pandas +numexpr +palpy +scipy +sqlalchemy +astropy +pytables +h5py +openorb +openorb-data-de405 +astroplan +colorcet +cycler +george +scikit-learn +requests +tqdm diff --git a/rubin_scheduler/__init__.py b/rubin_scheduler/__init__.py new file mode 100644 index 
00000000..58f3ace6 --- /dev/null +++ b/rubin_scheduler/__init__.py @@ -0,0 +1 @@ +from .version import __version__ diff --git a/rubin_scheduler/data/__init__.py b/rubin_scheduler/data/__init__.py new file mode 100644 index 00000000..052dbdd3 --- /dev/null +++ b/rubin_scheduler/data/__init__.py @@ -0,0 +1 @@ +from .data_sets import * # noqa: F403 diff --git a/rubin_scheduler/data/data_sets.py b/rubin_scheduler/data/data_sets.py new file mode 100644 index 00000000..49a832b8 --- /dev/null +++ b/rubin_scheduler/data/data_sets.py @@ -0,0 +1,62 @@ +__all__ = ("get_data_dir", "data_versions", "get_baseline") + +import glob +import os + + +def get_data_dir(): + """Get the location of the rubin_sim data directory. + + Returns + ------- + data_dir : `str` + Path to the rubin_sim data directory. + """ + # See if there is an environment variable with the path + data_dir = os.getenv("RUBIN_SIM_DATA_DIR") + + # Set the root data directory + if data_dir is None: + data_dir = os.path.join(os.getenv("HOME"), "rubin_sim_data") + return data_dir + + +def get_baseline(): + """Get the path to the baseline cadence simulation sqlite file. + + Returns + ------- + file : `str` + Path to the baseline cadence simulation sqlite file. + """ + dd = get_data_dir() + path = os.path.join(dd, "sim_baseline") + file = glob.glob(path + "/*10yrs.db")[0] + return file + + +def data_versions(): + """Get the dictionary of source filenames in the rubin_sim data directory. + + Returns + ------- + result : `dict` + Data directory filenames dictionary with keys: + ``"name"`` + Data bucket name (`str`). + ``"version"`` + Versioned file name (`str`). 
+ """ + data_dir = get_data_dir() + result = None + version_file = os.path.join(data_dir, "versions.txt") + if os.path.isfile(version_file): + with open(version_file) as f: + content = f.readlines() + content = [x.strip() for x in content] + result = {} + for line in content: + ack = line.split(",") + result[ack[0]] = ack[1] + + return result diff --git a/rubin_scheduler/data/rs_download_sky.py b/rubin_scheduler/data/rs_download_sky.py new file mode 100644 index 00000000..17873797 --- /dev/null +++ b/rubin_scheduler/data/rs_download_sky.py @@ -0,0 +1,91 @@ +__all__ = ("MyHTMLParser", "rs_download_sky") + +import argparse +import os +from html.parser import HTMLParser + +import requests + +from . import get_data_dir + + +# Hack it up to find the filenames ending with .h5 +class MyHTMLParser(HTMLParser): + """HTML parser class that uses the HTMLParser to parse a starttag. + + See Also + -------- + html.parser.HTMLParser + + Examples + -------- + To instantiate a MyHTMLParser instance: + + parser = MyHTMLParser() + parser.handle_starttag(tag, attrs) + """ + + def handle_starttag(self, tag, attrs): + """ + Handle the start tag of an element (e.g.
). + + Parameters + ---------- + tag : `str` + The name of the tag converted to lower case. + attrs : `list` + A list of (name, value) pairs containing the attributes + found inside the tag’s <> brackets + """ + try: + self.filenames + except AttributeError: + setattr(self, "filenames", []) + if tag == "a": + if attrs[0][0] == "href": + if attrs[0][1].endswith(".h5"): + self.filenames.append(attrs[0][1]) + + +def rs_download_sky(): + """Download sky files.""" + parser = argparse.ArgumentParser( + description="Download precomputed skybrightness files for rubin_sim package" + ) + parser.add_argument( + "-f", + "--force", + dest="force", + default=False, + action="store_true", + help="Force re-download of sky brightness data.", + ) + parser.add_argument( + "--url_base", + type=str, + default="https://s3df.slac.stanford.edu/groups/rubin/static/sim-data/sims_skybrightness_pre/h5_2023_09_12/", + help="Root URL of download location", + ) + args = parser.parse_args() + + data_dir = get_data_dir() + destination = os.path.join(data_dir, "skybrightness_pre") + if not os.path.isdir(data_dir): + os.mkdir(data_dir) + if not os.path.isdir(destination): + os.mkdir(destination) + + # Get the index file + r = requests.get(args.url_base) + # Find the filenames + parser = MyHTMLParser() + parser.feed(r.text) + parser.close() + # Copy the sky data files, if they're not already present + for file in parser.filenames: + if not os.path.isfile(os.path.join(destination, file)) or args.force: + url = args.url_base + file + print(f"Downloading file {file} from {url}") + r = requests.get(url) + with open(os.path.join(destination, file), "wb") as f: + f.write(r.content) diff --git a/rubin_scheduler/data/scheduler_download_data.py b/rubin_scheduler/data/scheduler_download_data.py new file mode 100644 index 00000000..2b9c5251 --- /dev/null +++ b/rubin_scheduler/data/scheduler_download_data.py @@ -0,0 +1,177 @@ +__all__ = ("data_dict", "scheduler_download_data") + +import argparse +import os 
+import warnings +from shutil import rmtree, unpack_archive + +import requests +from requests.exceptions import ConnectionError +from tqdm.auto import tqdm + +from .data_sets import data_versions, get_data_dir + +DEFAULT_DATA_URL = "https://s3df.slac.stanford.edu/data/rubin/sim-data/rubin_sim_data/" + + +def data_dict(): + """Creates a `dict` for all data buckets and the tar file they map to. + To create tar files and follow any sym links, run: + ``tar -chvzf maf_may_2021.tgz maf`` + + Returns + ------- + result : `dict` + Data bucket filenames dictionary with keys: + ``"name"`` + Data bucket name (`str`). + ``"version"`` + Versioned file name (`str`). + """ + file_dict = { + "scheduler": "scheduler_2023_10_16.tgz", + "site_models": "site_models_2023_10_02.tgz", + "skybrightness_pre": "skybrightness_pre_2023_10_17.tgz", + "tests": "tests_2022_10_18.tgz", + } + return file_dict + + +def scheduler_download_data(file_dict=None): + """Download data.""" + + if file_dict is None: + file_dict = data_dict() + parser = argparse.ArgumentParser(description="Download data files for rubin_sim package") + parser.add_argument( + "--versions", + dest="versions", + default=False, + action="store_true", + help="Report expected versions, then quit", + ) + parser.add_argument( + "-d", + "--dirs", + type=str, + default=None, + help="Comma-separated list of directories to download", + ) + parser.add_argument( + "-f", + "--force", + dest="force", + default=False, + action="store_true", + help="Force re-download of data directory(ies)", + ) + parser.add_argument( + "--url_base", + type=str, + default=DEFAULT_DATA_URL, + help="Root URL of download location", + ) + parser.add_argument( + "--orbits_pre", + dest="orbits", + default=False, + action="store_true", + help="Include pre-computed orbit files.", + ) + parser.add_argument( + "--tdqm_disable", + dest="tdqm_disable", + default=False, + action="store_true", + help="Turn off tdqm progress bar", + ) + args = parser.parse_args() + + dirs = 
args.dirs + if dirs is None: + dirs = file_dict.keys() + else: + dirs = dirs.split(",") + + data_dir = get_data_dir() + if not os.path.isdir(data_dir): + os.mkdir(data_dir) + version_file = os.path.join(data_dir, "versions.txt") + versions = data_versions() + if versions is None: + versions = {} + + if args.versions: + print("Versions on disk currently // versions expected for this release:") + match = True + for k in file_dict: + print(f"{k} : {versions.get(k, '')} // {file_dict[k]}") + if versions.get(k, "") != file_dict[k]: + match = False + if match: + print("Versions are in sync") + return 0 + else: + print("Versions do not match") + return 1 + + if not args.orbits: + dirs = [key for key in dirs if "orbits_precompute" not in key] + + # See if base URL is alive + url_base = args.url_base + fail_message = f"Could not connect to {args.url_base} or {url_base}. Check sites are up?" + try: + r = requests.get(url_base) + except ConnectionError: + print(fail_message) + exit() + if r.status_code != requests.codes.ok: + print(fail_message) + exit() + + for key in dirs: + filename = file_dict[key] + path = os.path.join(data_dir, key) + if os.path.isdir(path) and not args.force: + warnings.warn("Directory %s already exists, skipping download" % path) + else: + if os.path.isdir(path) and args.force: + rmtree(path) + warnings.warn("Removed existing directory %s, downloading new copy" % path) + # Download file + url = url_base + filename + print("Downloading file: %s" % url) + # Stream and write in chunks (avoid large memory usage) + r = requests.get(url, stream=True) + file_size = int(r.headers.get("Content-Length", 0)) + if file_size < 245: + warnings.warn(f"{url} file size unexpectedly small.") + # Download this size chunk at a time; reasonable guess + block_size = 1024 * 1024 + progress_bar = tqdm(total=file_size, unit="iB", unit_scale=True, disable=args.tdqm_disable) + print(f"Writing to {os.path.join(data_dir, filename)}") + with open(os.path.join(data_dir, filename), 
"wb") as f: + for chunk in r.iter_content(chunk_size=block_size): + progress_bar.update(len(chunk)) + f.write(chunk) + progress_bar.close() + # untar in place + unpack_archive(os.path.join(data_dir, filename), data_dir) + os.remove(os.path.join(data_dir, filename)) + versions[key] = file_dict[key] + + # Write out the new version info to the data directory + with open(version_file, "w") as f: + for key in versions: + print(key + "," + versions[key], file=f) + + # Write a little table to stdout + new_versions = data_versions() + print("Current/updated data versions:") + for k in new_versions: + if len(k) <= 10: + sep = "\t\t" + else: + sep = "\t" + print(f"{k}{sep}{new_versions[k]}") diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..f8de51c8 --- /dev/null +++ b/setup.py @@ -0,0 +1,4 @@ +import setuptools_scm +from setuptools import setup + +setup(version=setuptools_scm.get_version()) diff --git a/test-requirements.txt b/test-requirements.txt new file mode 100644 index 00000000..fd467032 --- /dev/null +++ b/test-requirements.txt @@ -0,0 +1,5 @@ +pytest +pytest-cov +pytest-black +black +ruff diff --git a/tests/data/test_data.py b/tests/data/test_data.py new file mode 100644 index 00000000..4423dd0f --- /dev/null +++ b/tests/data/test_data.py @@ -0,0 +1,18 @@ +import unittest + +from rubin_scheduler.data import data_versions, get_data_dir + + +class DataTest(unittest.TestCase): + def testData(self): + """ + Check that basic data tools work + """ + data_dir = get_data_dir() + versions = data_versions() + assert data_dir is not None + assert versions is not None + + +if __name__ == "__main__": + unittest.main()