YAML-run pipeline, part 2 (data preprocessing) (#28)
_This resolves #26._

Until now, we've been missing the ability to easily configure how data
is preprocessed.

The `linmod.data` script now accepts a single command-line argument: the path to a
YAML file configuring its behavior. The argument is optional; without it, the default
behavior (given by the dictionary `linmod.data.DEFAULT_CONFIG`) is used. The YAML file
only needs to define the keys it wants to override; missing keys are populated with
the default values.

An example is given in `present-day-forecasting/config.yaml`. As
described in the README, this is run as `python3 -m linmod.data
config.yaml`.
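
For illustration, a partial override might look like the sketch below (hypothetical values; the actual example file is `present-day-forecasting/config.yaml` and is not reproduced here). Keys omitted from the file fall back to `DEFAULT_CONFIG`:

```yaml
# Hypothetical config.yaml: only the keys to override need to be listed.
data:
  forecast_date:
    year: 2024
    month: 8
    day: 15
  horizon:
    lower: -90
    upper: 14
  save_path:
    model: data/metadata-model.csv
    eval: data/metadata-eval.csv
```

Note that the script merges the file into the defaults with a shallow dictionary update (`|=`) on the `data` key, so a nested block such as `forecast_date` should be given in full when it is overridden.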

---------

Co-authored-by: afmagee42 <[email protected]>
thanasibakis and afmagee42 authored Aug 15, 2024
1 parent 438ae92 commit 4947165
Showing 6 changed files with 245 additions and 162 deletions.
2 changes: 2 additions & 0 deletions docs/api.md
@@ -1,5 +1,7 @@
# API reference

::: linmod.data

::: linmod.models

::: linmod.eval
330 changes: 193 additions & 137 deletions linmod/data.py
@@ -1,141 +1,161 @@
"""
Usage: `python3 -m linmod.data`
Usage: `python3 -m linmod.data [path/to/config.yaml]`
Download the Nextstrain metadata file, preprocess it, and print the result to `stdout`.
Download the Nextstrain metadata file, preprocess it, and export it.
Two datasets are exported: one for model fitting and one for evaluation.
The model dataset contains sequences collected and reported by a specified
forecast date, while the evaluation dataset extends the horizon into the future.
To change the default behavior, create a YAML configuration file with a top-level
`data` key and pass its path when invoking this script. For a list of configurable
sub-keys, see the `DEFAULT_CONFIG` dictionary.
The output is given in CSV format, with columns `date`, `fd_offset`, `division`,
`lineage`, `count`. Rows are uniquely identified by `(date, division, lineage)`.
`date` and `fd_offset` can be computed from each other, given the forecast date;
the `fd_offset` column is the number of days between the forecast date
(defaults to today) and the `date` column.
Preprocessing is done to ensure that:
- The most recent 90 days of sequences are included;
- Observations without a recorded date are removed;
- Only the 50 U.S. states, D.C., and Puerto Rico are included; and
- Only observations from human hosts are included.
the `fd_offset` column is the number of days between the forecast date and the `date`
column, such that, for example, 0 is the forecast date, -1 the day before, and 1 the
day after.
The data is downloaded from:
https://data.nextstrain.org/files/ncov/open/metadata.tsv.zst
Note that observations without a recorded date are removed, and only observations
from human hosts are included.
"""

import lzma
import os
import sys
from datetime import datetime
from pathlib import Path
from urllib.parse import urlparse
from urllib.request import urlopen

import polars as pl
import yaml
import zstandard

# Configuration

# Where should the unprocessed (but decompressed) data be stored?
CACHE_DIRECTORY = Path(".cache")

# Where should the data be downloaded from?
DATA_SOURCE = "https://data.nextstrain.org/files/ncov/open/metadata.tsv.zst"

# What column should be renamed to `lineage`?
LINEAGE_COLUMN_NAME = "clade_nextstrain"

# How many days of sequences should be included?
NUM_DAYS = 90

# Which divisions should be included?
# Currently set to the 50 U.S. states, D.C., and Puerto Rico
INCLUDED_DIVISIONS = [
"Alabama",
"Alaska",
"Arizona",
"Arkansas",
"California",
"Colorado",
"Connecticut",
"Delaware",
"Florida",
"Georgia",
"Hawaii",
"Idaho",
"Illinois",
"Indiana",
"Iowa",
"Kansas",
"Kentucky",
"Louisiana",
"Maine",
"Maryland",
"Massachusetts",
"Michigan",
"Minnesota",
"Mississippi",
"Missouri",
"Montana",
"Nebraska",
"Nevada",
"New Hampshire",
"New Jersey",
"New Mexico",
"New York",
"North Carolina",
"North Dakota",
"Ohio",
"Oklahoma",
"Oregon",
"Pennsylvania",
"Puerto Rico",
"Rhode Island",
"South Carolina",
"South Dakota",
"Tennessee",
"Texas",
"Utah",
"Vermont",
"Virginia",
"Washington",
"Washington DC",
"West Virginia",
"Wisconsin",
"Wyoming",
]


def load_metadata(
forecast_date: tuple | None = None,
redownload: bool = False,
) -> pl.DataFrame:
"""
Download the metadata file, preprocess it, and return a `polars.DataFrame`.
The data is filtered to include only the most recent `NUM_DAYS` days of sequences
collected by `forecast_date`, specified as a tuple `(year, month, day)`
(defaulting to today's date if not specified). The column specified by
`LINEAGE_COLUMN_NAME` is renamed to `lineage`. The unprocessed (but decompressed)
data is cached in the `CACHE_DIRECTORY`. If `redownload`, the data is redownloaded,
and the cache is replaced.
"""

if forecast_date is None:
now = datetime.now()
forecast_date = (now.year, now.month, now.day)

parsed_url = urlparse(DATA_SOURCE)
save_path = (
CACHE_DIRECTORY
from .utils import print_message

DEFAULT_CONFIG = {
"data": {
# Where should the data be downloaded from?
"source": "https://data.nextstrain.org/files/ncov/open/metadata.tsv.zst",
# Where (directory) should the unprocessed (but decompressed) data be stored?
"cache_dir": ".cache/",
# Where (files) should the processed datasets for modeling and evaluation
# be stored?
"save_path": {
"model": "data/metadata-model.csv",
"eval": "data/metadata-eval.csv",
},
# Should the data be redownloaded (and the cache replaced)?
"redownload": False,
# What column should be renamed to `lineage`?
"lineage_column_name": "clade_nextstrain",
# What is the forecast date?
# No sequences collected or reported after this date are included in the
# modeling dataset.
"forecast_date": {
"year": datetime.now().year,
"month": datetime.now().month,
"day": datetime.now().day,
},
# How many days since the forecast date should be included in the datasets?
# The evaluation dataset will contain sequences collected and reported within
# this horizon. The modeling dataset will contain sequences collected and
# reported within the horizon `[lower, 0]`.
"horizon": {
"lower": -90,
"upper": 14,
},
# Which divisions should be included?
# Currently set to the 50 U.S. states, D.C., and Puerto Rico
"included_divisions": [
"Alabama",
"Alaska",
"Arizona",
"Arkansas",
"California",
"Colorado",
"Connecticut",
"Delaware",
"Florida",
"Georgia",
"Hawaii",
"Idaho",
"Illinois",
"Indiana",
"Iowa",
"Kansas",
"Kentucky",
"Louisiana",
"Maine",
"Maryland",
"Massachusetts",
"Michigan",
"Minnesota",
"Mississippi",
"Missouri",
"Montana",
"Nebraska",
"Nevada",
"New Hampshire",
"New Jersey",
"New Mexico",
"New York",
"North Carolina",
"North Dakota",
"Ohio",
"Oklahoma",
"Oregon",
"Pennsylvania",
"Puerto Rico",
"Rhode Island",
"South Carolina",
"South Dakota",
"Tennessee",
"Texas",
"Utah",
"Vermont",
"Virginia",
"Washington",
"Washington DC",
"West Virginia",
"Wisconsin",
"Wyoming",
],
}
}
"""
Default configuration for data download, preprocessing, and export.
The configuration dictionary expects all of the following entries in a
`data` key.
"""

if __name__ == "__main__":
# Load configuration, if given

config = DEFAULT_CONFIG

if len(sys.argv) > 1:
with open(sys.argv[1]) as f:
config["data"] |= yaml.safe_load(f)["data"]

# Download the data, if necessary

parsed_url = urlparse(config["data"]["source"])
cache_path = (
Path(config["data"]["cache_dir"])
/ parsed_url.netloc
/ parsed_url.path.lstrip("/").rsplit(".", 1)[0]
)
# TODO: should cache save path incorporate `forecast_date`?

# Download the data if necessary
if redownload or not os.path.exists(save_path):
print("Downloading...", file=sys.stderr, flush=True, end="")
if config["data"]["redownload"] or not cache_path.exists():
print_message("Downloading...", end="")

save_path.parent.mkdir(parents=True, exist_ok=True)
cache_path.parent.mkdir(parents=True, exist_ok=True)

with urlopen(DATA_SOURCE) as response, save_path.open(
with urlopen(config["data"]["source"]) as response, cache_path.open(
"wb"
) as out_file:
if parsed_url.path.endswith(".gz"):
@@ -151,48 +171,84 @@ def load_metadata(
else:
raise ValueError(f"Unsupported file format: {parsed_url.path}")

print(" done.", file=sys.stderr, flush=True)
print_message(" done.")
else:
print("Using cached data.", file=sys.stderr, flush=True)
print_message("Using cached data.")

# Preprocess and export the data

# Preprocess the data
print("Preprocessing data...", file=sys.stderr, flush=True, end="")
print_message("Exporting evaluation dataset...", end="")

forecast_date = pl.date(
config["data"]["forecast_date"]["year"],
config["data"]["forecast_date"]["month"],
config["data"]["forecast_date"]["day"],
)

horizon_lower_date = forecast_date.dt.offset_by(
f'{config["data"]["horizon"]["lower"]}d'
)
horizon_upper_date = forecast_date.dt.offset_by(
f'{config["data"]["horizon"]["upper"]}d'
)

df = (
pl.scan_csv(save_path, separator="\t")
.rename({LINEAGE_COLUMN_NAME: "lineage"})
full_df = (
pl.scan_csv(cache_path, separator="\t")
.rename({config["data"]["lineage_column_name"]: "lineage"})
# Cast with `strict=False` replaces invalid values with null,
# which we can then filter out. Invalid values include dates
# that are resolved only to the month, not the day
.cast({"date": pl.Date}, strict=False)
.cast({"date": pl.Date, "date_submitted": pl.Date}, strict=False)
.filter(
# Drop samples with missing collection or reporting dates
pl.col("date").is_not_null(),
pl.col("date") <= pl.date(*forecast_date),
pl.col("division").is_in(INCLUDED_DIVISIONS),
pl.col("date_submitted").is_not_null(),
# Drop samples collected outside the horizon
horizon_lower_date <= pl.col("date"),
pl.col("date") <= horizon_upper_date,
# Drop samples claiming to be reported before being collected
pl.col("date") <= pl.col("date_submitted"),
# Drop samples not from humans in the included US divisions
pl.col("division").is_in(config["data"]["included_divisions"]),
country="USA",
host="Homo sapiens",
)
.filter(
pl.col("date") >= pl.col("date").max() - 90,
)
.group_by("lineage", "date", "division")
)

eval_df = (
full_df.group_by("lineage", "date", "division")
.agg(pl.len().alias("count"))
.with_columns(
fd_offset=(
pl.col("date") - pl.date(*forecast_date)
).dt.total_days()
fd_offset=(pl.col("date") - forecast_date).dt.total_days()
)
.select("date", "fd_offset", "division", "lineage", "count")
.collect()
)

print(" done.", file=sys.stderr, flush=True)
Path(config["data"]["save_path"]["eval"]).parent.mkdir(
parents=True, exist_ok=True
)

return df
eval_df.write_csv(config["data"]["save_path"]["eval"])

print_message(" done.")
print_message("Exporting modeling dataset...", end="")

if __name__ == "__main__":
data = load_metadata()
# TODO: an argparse setup to specify `forecast_date`
model_df = (
full_df.filter(pl.col("date_submitted") <= forecast_date)
.group_by("lineage", "date", "division")
.agg(pl.len().alias("count"))
.with_columns(
fd_offset=(pl.col("date") - forecast_date).dt.total_days()
)
.select("date", "fd_offset", "division", "lineage", "count")
.collect()
)

Path(config["data"]["save_path"]["model"]).parent.mkdir(
parents=True, exist_ok=True
)

model_df.write_csv(config["data"]["save_path"]["model"])

print(data.collect().write_csv(), end="")
print("\nSuccess.", file=sys.stderr)
print_message(" done.")