diff --git a/.github/workflows/build_docs.yaml b/.github/workflows/build_docs.yaml
index cf983fa1..c81a7b15 100644
--- a/.github/workflows/build_docs.yaml
+++ b/.github/workflows/build_docs.yaml
@@ -42,7 +42,7 @@ jobs:
shell: bash -l {0}
run: |
export RUBIN_SIM_DATA_DIR=${{ github.workspace }}
- rs_download_data --tdqm_disable -d 'scheduler'
+ scheduler_download_data --tdqm_disable -d 'scheduler'
- name: check conda and documenteer
shell: bash -l {0}
diff --git a/.github/workflows/build_pypi.yaml b/.github/workflows/build_pypi.yaml
index f63d57fd..e32a5c10 100644
--- a/.github/workflows/build_pypi.yaml
+++ b/.github/workflows/build_pypi.yaml
@@ -46,8 +46,8 @@ jobs:
shell: bash -l {0}
run: |
export RUBIN_SIM_DATA_DIR=${{ github.workspace }}/data_dir
- rs_download_data --tdqm_disable -d 'site_models,skybrightness_pre,scheduler'
- rs_download_data --tdqm_disable -d tests --force
+ scheduler_download_data --tdqm_disable -d 'site_models,skybrightness_pre,scheduler'
+ scheduler_download_data --tdqm_disable -d tests --force
- name: conda list
shell: bash -l {0}
diff --git a/.github/workflows/run_all_tests.yaml b/.github/workflows/run_all_tests.yaml
index 223705ce..72d4739f 100644
--- a/.github/workflows/run_all_tests.yaml
+++ b/.github/workflows/run_all_tests.yaml
@@ -47,7 +47,7 @@ jobs:
shell: bash -l {0}
run: |
export RUBIN_SIM_DATA_DIR=${{ github.workspace }}/data_dir
- rs_download_data --force --tdqm_disable
+ scheduler_download_data --force --tdqm_disable
- name: conda list
shell: bash -l {0}
diff --git a/.github/workflows/run_tests_docs.yaml b/.github/workflows/run_tests_docs.yaml
index cb14165a..1414e528 100644
--- a/.github/workflows/run_tests_docs.yaml
+++ b/.github/workflows/run_tests_docs.yaml
@@ -49,8 +49,8 @@ jobs:
shell: bash -l {0}
run: |
export RUBIN_SIM_DATA_DIR=${{ github.workspace }}/data_dir
- rs_download_data --tdqm_disable -d 'site_models,scheduler,skybrightness_pre'
- rs_download_data --tdqm_disable -d tests --force
+ scheduler_download_data --tdqm_disable -d 'site_models,scheduler,skybrightness_pre'
+ scheduler_download_data --tdqm_disable -d tests --force
- name: conda list
shell: bash -l {0}
@@ -94,7 +94,7 @@ jobs:
shell: bash -l {0}
run: |
export RUBIN_SIM_DATA_DIR=${{ github.workspace }}
- rs_download_data --tdqm_disable -d 'scheduler'
+ scheduler_download_data --tdqm_disable -d 'scheduler'
- name: check conda and documenteer
shell: bash -l {0}
diff --git a/.gitignore b/.gitignore
index 68bc17f9..e09d322f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -146,6 +146,15 @@ dmypy.json
# Pyre type checker
.pyre/
+# version file
+version.py
+
+# pycharm files
+.idea/
+
+# vscode files
+.vscode/
+
# pytype static type analyzer
.pytype/
diff --git a/pyproject.toml b/pyproject.toml
index b672b140..eeb91eee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,7 @@ classifiers = [
"Programming Language :: Python :: 3.11",
"Topic :: Scientific/Engineering :: Astronomy",
]
-urls = {documentation = "https://rubin-sim.lsst.io", repository = "https://github.com/lsst/rubin_sim" }
+urls = {documentation = "https://rubin-scheduler.lsst.io", repository = "https://github.com/lsst/rubin_scheduler" }
dynamic = [ "version" ]
dependencies = [
"numpy",
@@ -52,7 +52,7 @@ dev = [
]
[project.scripts]
-rs_download_data = "rubin_scheduler.data.rs_download_data:rs_download_data"
+scheduler_download_data = "rubin_scheduler.data.scheduler_download_data:scheduler_download_data"
rs_download_sky = "rubin_scheduler.data.rs_download_sky:rs_download_sky"
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 00000000..3fba1667
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,22 @@
+setuptools_scm
+setuptools_scm_git_archive
+numpy
+matplotlib
+healpy
+pandas
+numexpr
+palpy
+scipy
+sqlalchemy
+astropy
+pytables
+h5py
+openorb
+openorb-data-de405
+astroplan
+colorcet
+cycler
+george
+scikit-learn
+requests
+tqdm
diff --git a/rubin_scheduler/__init__.py b/rubin_scheduler/__init__.py
new file mode 100644
index 00000000..58f3ace6
--- /dev/null
+++ b/rubin_scheduler/__init__.py
@@ -0,0 +1 @@
+from .version import __version__
diff --git a/rubin_scheduler/data/__init__.py b/rubin_scheduler/data/__init__.py
new file mode 100644
index 00000000..052dbdd3
--- /dev/null
+++ b/rubin_scheduler/data/__init__.py
@@ -0,0 +1 @@
+from .data_sets import * # noqa: F403
diff --git a/rubin_scheduler/data/data_sets.py b/rubin_scheduler/data/data_sets.py
new file mode 100644
index 00000000..49a832b8
--- /dev/null
+++ b/rubin_scheduler/data/data_sets.py
@@ -0,0 +1,62 @@
+__all__ = ("get_data_dir", "data_versions", "get_baseline")
+
+import glob
+import os
+
+
+def get_data_dir():
+ """Get the location of the rubin_sim data directory.
+
+ Returns
+ -------
+ data_dir : `str`
+ Path to the data directory: $RUBIN_SIM_DATA_DIR if set, else $HOME/rubin_sim_data.
+ """
+ # See if there is an environment variable with the path
+ data_dir = os.getenv("RUBIN_SIM_DATA_DIR")
+
+ # Set the root data directory
+ if data_dir is None:
+ data_dir = os.path.join(os.getenv("HOME"), "rubin_sim_data")
+ return data_dir
+
+
+def get_baseline():
+ """Get the path to the baseline cadence simulation sqlite file.
+
+ Returns
+ -------
+ file : `str`
+ Path to the baseline cadence simulation sqlite file.
+ """
+ dd = get_data_dir()
+ path = os.path.join(dd, "sim_baseline")
+ file = glob.glob(path + "/*10yrs.db")[0]
+ return file
+
+
+def data_versions():
+ """Get the dictionary of source filenames in the rubin_sim data directory.
+
+ Returns
+ -------
+ result : `dict` or None
+ Dictionary of the data versions on disk, mapping
+ each data bucket name (`str`) to the versioned
+ file name (`str`) recorded in ``versions.txt``.
+ None if no ``versions.txt`` file exists in the
+ data directory.
+ """
+ data_dir = get_data_dir()
+ result = None
+ version_file = os.path.join(data_dir, "versions.txt")
+ if os.path.isfile(version_file):
+ with open(version_file) as f:
+ content = f.readlines()
+ content = [x.strip() for x in content]
+ result = {}
+ for line in content:
+ ack = line.split(",")
+ result[ack[0]] = ack[1]
+
+ return result
diff --git a/rubin_scheduler/data/rs_download_sky.py b/rubin_scheduler/data/rs_download_sky.py
new file mode 100644
index 00000000..17873797
--- /dev/null
+++ b/rubin_scheduler/data/rs_download_sky.py
@@ -0,0 +1,91 @@
+__all__ = ("MyHTMLParser", "rs_download_sky")
+
+import argparse
+import os
+from html.parser import HTMLParser
+
+import requests
+
+from . import get_data_dir
+
+
+# Hack it up to find the filenames ending with .h5
+class MyHTMLParser(HTMLParser):
+ """HTML parser class that uses the HTMLParser to parse a starttag.
+
+ See Also
+ --------
+ html.parser.HTMLParser
+
+ Examples
+ --------
+ To instantiate a MyHTMLParser instance:
+
+ parser = MyHTMLParser()
+ parser.handle_starttag(tag, attrs)
+ """
+
+ def handle_starttag(self, tag, attrs):
+ """
+ Handle the start tag of an element (e.g. an
+ anchor ``<a href=...>`` tag).
+
+ Parameters
+ ----------
+ tag : `str`
+ The name of the tag converted to lower case.
+ attrs : `list`
+ A list of (name, value) pairs containing the attributes
+ found inside the tag's angle brackets.
+ """
+ try:
+ self.filenames
+ except AttributeError:
+ setattr(self, "filenames", [])
+ if tag == "a":
+ if attrs[0][0] == "href":
+ if attrs[0][1].endswith(".h5"):
+ self.filenames.append(attrs[0][1])
+
+
+def rs_download_sky():
+ """Download sky files."""
+ parser = argparse.ArgumentParser(
+ description="Download precomputed skybrightness files for rubin_sim package"
+ )
+ parser.add_argument(
+ "-f",
+ "--force",
+ dest="force",
+ default=False,
+ action="store_true",
+ help="Force re-download of sky brightness data.",
+ )
+ parser.add_argument(
+ "--url_base",
+ type=str,
+ default="https://s3df.slac.stanford.edu/groups/rubin/static/sim-data/sims_skybrightness_pre/h5_2023_09_12/",
+ help="Root URL of download location",
+ )
+ args = parser.parse_args()
+
+ data_dir = get_data_dir()
+ destination = os.path.join(data_dir, "skybrightness_pre")
+ if not os.path.isdir(data_dir):
+ os.mkdir(data_dir)
+ if not os.path.isdir(destination):
+ os.mkdir(destination)
+
+ # Get the index file
+ r = requests.get(args.url_base)
+ # Find the filenames
+ parser = MyHTMLParser()
+ parser.feed(r.text)
+ parser.close()
+ # Copy the sky data files, if they're not already present
+ for file in parser.filenames:
+ if not os.path.isfile(os.path.join(destination, file)) or args.force:
+ url = args.url_base + file
+ print(f"Downloading file {file} from {url}")
+ r = requests.get(url)
+ with open(os.path.join(destination, file), "wb") as f:
+ f.write(r.content)
diff --git a/rubin_scheduler/data/scheduler_download_data.py b/rubin_scheduler/data/scheduler_download_data.py
new file mode 100644
index 00000000..2b9c5251
--- /dev/null
+++ b/rubin_scheduler/data/scheduler_download_data.py
@@ -0,0 +1,177 @@
+__all__ = ("data_dict", "scheduler_download_data")
+
+import argparse
+import os
+import warnings
+from shutil import rmtree, unpack_archive
+
+import requests
+from requests.exceptions import ConnectionError
+from tqdm.auto import tqdm
+
+from .data_sets import data_versions, get_data_dir
+
+DEFAULT_DATA_URL = "https://s3df.slac.stanford.edu/data/rubin/sim-data/rubin_sim_data/"
+
+
+def data_dict():
+ """Creates a `dict` for all data buckets and the tar file they map to.
+ To create tar files and follow any sym links, run:
+ ``tar -chvzf maf_may_2021.tgz maf``
+
+ Returns
+ -------
+ result : `dict`
+ Dictionary of the data buckets to download, mapping
+ each data bucket name (`str`) to the versioned
+ tar file name (`str`) it corresponds to, e.g.
+ ``{"scheduler": "scheduler_2023_10_16.tgz"}``.
+
+ """
+ file_dict = {
+ "scheduler": "scheduler_2023_10_16.tgz",
+ "site_models": "site_models_2023_10_02.tgz",
+ "skybrightness_pre": "skybrightness_pre_2023_10_17.tgz",
+ "tests": "tests_2022_10_18.tgz",
+ }
+ return file_dict
+
+
+def scheduler_download_data(file_dict=None):
+ """Download data."""
+
+ if file_dict is None:
+ file_dict = data_dict()
+ parser = argparse.ArgumentParser(description="Download data files for rubin_sim package")
+ parser.add_argument(
+ "--versions",
+ dest="versions",
+ default=False,
+ action="store_true",
+ help="Report expected versions, then quit",
+ )
+ parser.add_argument(
+ "-d",
+ "--dirs",
+ type=str,
+ default=None,
+ help="Comma-separated list of directories to download",
+ )
+ parser.add_argument(
+ "-f",
+ "--force",
+ dest="force",
+ default=False,
+ action="store_true",
+ help="Force re-download of data directory(ies)",
+ )
+ parser.add_argument(
+ "--url_base",
+ type=str,
+ default=DEFAULT_DATA_URL,
+ help="Root URL of download location",
+ )
+ parser.add_argument(
+ "--orbits_pre",
+ dest="orbits",
+ default=False,
+ action="store_true",
+ help="Include pre-computed orbit files.",
+ )
+ parser.add_argument(
+ "--tdqm_disable",
+ dest="tdqm_disable",
+ default=False,
+ action="store_true",
+ help="Turn off tdqm progress bar",
+ )
+ args = parser.parse_args()
+
+ dirs = args.dirs
+ if dirs is None:
+ dirs = file_dict.keys()
+ else:
+ dirs = dirs.split(",")
+
+ data_dir = get_data_dir()
+ if not os.path.isdir(data_dir):
+ os.mkdir(data_dir)
+ version_file = os.path.join(data_dir, "versions.txt")
+ versions = data_versions()
+ if versions is None:
+ versions = {}
+
+ if args.versions:
+ print("Versions on disk currently // versions expected for this release:")
+ match = True
+ for k in file_dict:
+ print(f"{k} : {versions.get(k, '')} // {file_dict[k]}")
+ if versions.get(k, "") != file_dict[k]:
+ match = False
+ if match:
+ print("Versions are in sync")
+ return 0
+ else:
+ print("Versions do not match")
+ return 1
+
+ if not args.orbits:
+ dirs = [key for key in dirs if "orbits_precompute" not in key]
+
+ # See if base URL is alive
+ url_base = args.url_base
+ try:
+ r = requests.get(url_base)
+ fail_message = f"Could not connect to {args.url_base} or {url_base}. Check sites are up?"
+ except ConnectionError:
+ print(fail_message)
+ exit()
+ if r.status_code != requests.codes.ok:
+ print(fail_message)
+ exit()
+
+ for key in dirs:
+ filename = file_dict[key]
+ path = os.path.join(data_dir, key)
+ if os.path.isdir(path) and not args.force:
+ warnings.warn("Directory %s already exists, skipping download" % path)
+ else:
+ if os.path.isdir(path) and args.force:
+ rmtree(path)
+ warnings.warn("Removed existing directory %s, downloading new copy" % path)
+ # Download file
+ url = url_base + filename
+ print("Downloading file: %s" % url)
+ # Stream and write in chunks (avoid large memory usage)
+ r = requests.get(url, stream=True)
+ file_size = int(r.headers.get("Content-Length", 0))
+ if file_size < 245:
+ warnings.warn(f"{url} file size unexpectedly small.")
+ # Download this size chunk at a time; reasonable guess
+ block_size = 1024 * 1024
+ progress_bar = tqdm(total=file_size, unit="iB", unit_scale=True, disable=args.tdqm_disable)
+ print(f"Writing to {os.path.join(data_dir, filename)}")
+ with open(os.path.join(data_dir, filename), "wb") as f:
+ for chunk in r.iter_content(chunk_size=block_size):
+ progress_bar.update(len(chunk))
+ f.write(chunk)
+ progress_bar.close()
+ # untar in place
+ unpack_archive(os.path.join(data_dir, filename), data_dir)
+ os.remove(os.path.join(data_dir, filename))
+ versions[key] = file_dict[key]
+
+ # Write out the new version info to the data directory
+ with open(version_file, "w") as f:
+ for key in versions:
+ print(key + "," + versions[key], file=f)
+
+ # Write a little table to stdout
+ new_versions = data_versions()
+ print("Current/updated data versions:")
+ for k in new_versions:
+ if len(k) <= 10:
+ sep = "\t\t"
+ else:
+ sep = "\t"
+ print(f"{k}{sep}{new_versions[k]}")
diff --git a/setup.py b/setup.py
new file mode 100644
index 00000000..f8de51c8
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,4 @@
+import setuptools_scm
+from setuptools import setup
+
+setup(version=setuptools_scm.get_version())
diff --git a/test-requirements.txt b/test-requirements.txt
new file mode 100644
index 00000000..fd467032
--- /dev/null
+++ b/test-requirements.txt
@@ -0,0 +1,5 @@
+pytest
+pytest-cov
+pytest-black
+black
+ruff
diff --git a/tests/data/test_data.py b/tests/data/test_data.py
new file mode 100644
index 00000000..4423dd0f
--- /dev/null
+++ b/tests/data/test_data.py
@@ -0,0 +1,18 @@
+import unittest
+
+from rubin_scheduler.data import data_versions, get_data_dir
+
+
+class DataTest(unittest.TestCase):
+ def testData(self):
+ """
+ Check that basic data tools work
+ """
+ data_dir = get_data_dir()
+ versions = data_versions()
+ assert data_dir is not None
+ assert versions is not None
+
+
+if __name__ == "__main__":
+ unittest.main()