feat(ingest): allow option to group segments using source of truth
anna-parker committed Jan 30, 2025
1 parent 72dd94a commit 0aeed39
Showing 5 changed files with 317 additions and 10 deletions.
64 changes: 56 additions & 8 deletions ingest/Snakefile
@@ -27,11 +27,13 @@ FILTER_FASTA_HEADERS = config.get("filter_fasta_headers", None)
APPROVE_TIMEOUT_MIN = config.get("approve_timeout_min") # time in minutes
CHECK_ENA_DEPOSITION = config.get("check_ena_deposition", False)
ALIGN = True
GROUPS_JSON = None

dataset_server_map = {}
dataset_name_map = {}

if SEGMENTED:
GROUPS_JSON = config.get("grouping_ground_truth", None) # JSON map from group to segments' insdcAccessionFull
if config.get("minimizer_index") and config.get("minimizer_parser"):
ALIGN = False
if ALIGN:
@@ -79,7 +81,7 @@ rule fetch_ncbi_dataset_package:
datasets download virus genome taxon {params.taxon_id} \
--no-progressbar \
--filename {output.dataset_package} \
--api-key {params.api_key} \
--api-key {params.api_key}\
"""


@@ -255,7 +257,7 @@ if ALIGN:
"""

if not ALIGN:
rule download:
rule download_minimizer:
output:
results="results/minimzer.json",
params:
@@ -275,7 +277,7 @@
"""
nextclade sort -m {input.minimizer} \
-r {output.results} {input.sequences} \
--max-score-gap 0.3 --min-score 0.1 --min-hits 2 --all-matches
--max-score-gap 0.3 --min-score 0.05 --min-hits 2 --all-matches
"""

rule parse_sort:
@@ -324,18 +326,60 @@ rule prepare_metadata:
--log-level {params.log_level} \
"""

if GROUPS_JSON:
rule download_groups:
output:
results="results/groups.json",
params:
grouping=config.get("grouping_ground_truth")
shell:
"""
curl -L -o {output.results} {params.grouping}
"""

rule group_segments:
input:
script="scripts/deterministic_group_segments.py",
metadata="results/metadata_post_prepare.ndjson",
sequences="results/sequences.ndjson",
config="results/config.yaml",
groups="results/groups.json",
output:
metadata="results/metadata_grouped.ndjson",
sequences="results/sequences_grouped.ndjson",
ungrouped_metadata="results/metadata_ungrouped.ndjson",
ungrouped_sequences="results/sequences_ungrouped.ndjson",
params:
log_level=LOG_LEVEL,
shell:
"""
python {input.script} \
--config-file {input.config} \
--groups {input.groups} \
--input-metadata {input.metadata} \
--input-seq {input.sequences} \
--output-metadata {output.metadata} \
--output-ungrouped-metadata {output.ungrouped_metadata} \
--output-ungrouped-seq {output.ungrouped_sequences} \
--output-seq {output.sequences} \
--log-level {params.log_level} \
"""


rule group_segments:
rule heuristic_group_segments:
input:
script="scripts/group_segments.py",
metadata="results/metadata_post_prepare.ndjson",
sequences="results/sequences.ndjson",
script="scripts/heuristic_group_segments.py",
metadata="results/metadata_ungrouped.ndjson" if GROUPS_JSON else "results/metadata_post_prepare.ndjson",
sequences="results/sequences_ungrouped.ndjson" if GROUPS_JSON else "results/sequences.ndjson",
metadata_grouped="results/metadata_grouped.ndjson" if GROUPS_JSON else "results/metadata_post_prepare.ndjson",
sequences_grouped="results/sequences_grouped.ndjson" if GROUPS_JSON else "results/sequences.ndjson",
config="results/config.yaml",
output:
metadata="results/metadata_post_group.ndjson",
sequences="results/sequences_post_group.ndjson",
params:
log_level=LOG_LEVEL,
groups_json="true" if GROUPS_JSON else "false"
shell:
"""
python {input.script} \
@@ -344,7 +388,11 @@
--input-seq {input.sequences} \
--output-metadata {output.metadata} \
--output-seq {output.sequences} \
--log-level {params.log_level} \
--log-level {params.log_level}
if [ "{params.groups_json}" = "true" ]; then
cat {input.metadata_grouped} >> {output.metadata}
cat {input.sequences_grouped} >> {output.sequences}
fi
"""


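The new grouping_ground_truth option points at a JSON map from group name to the list of that group's segment accessions (insdcAccessionFull), which is downloaded by rule download_groups and consumed by deterministic_group_segments.py below. A minimal hypothetical example (the group name is illustrative; the accessions are taken from the script's docstring):

    {
        "CCHFV-isolate-1": ["KJ682796.1", "KJ682809.1", "KJ682819.1"]
    }
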
259 changes: 259 additions & 0 deletions ingest/scripts/deterministic_group_segments.py
@@ -0,0 +1,259 @@
"""Script to group segments together into sequence entries prior to submission to Loculus
given a json with known groups.
Example output for a single isolate with 3 segments:
{"id": "KJ682796.1.L/KJ682809.1.M/KJ682819.1.S",
"metadata": {
"ncbiReleaseDate": "2014-07-06T00:00:00Z",
"ncbiSourceDb": "GenBank",
"authors": "D. Goedhals, F.J. Burt, J.T. Bester, R. Swanepoel",
"insdcVersion_L": "1",
"insdcVersion_M": "1",
"insdcVersion_S": "1",
"insdcAccessionFull_L": "KJ682796.1",
"insdcAccessionFull_M": "KJ682809.1",
"insdcAccessionFull_S": "KJ682819.1",
"hash_L": "ddbfc33d45267e9c1a08f8f5e76d3e39",
"hash_M": "f64777883ba9f5293257698255767f2c",
"hash_S": "f716ed13dca9c8a033d46da2f3dc2ff1",
"hash": "ce7056d0bd7e3d6d3eca38f56b9d10f8",
"submissionId": "KJ682796.1.L/KJ682809.1.M/KJ682819.1.S"
}}"""

import hashlib
import json
import logging
import pathlib
from dataclasses import dataclass
from typing import Final

import click
import orjsonl
import yaml

logger = logging.getLogger(__name__)
logging.basicConfig(
encoding="utf-8",
level=logging.DEBUG,
format="%(asctime)s %(levelname)8s (%(filename)20s:%(lineno)4d) - %(message)s ",
datefmt="%H:%M:%S",
)


@dataclass(frozen=True)
class Config:
compound_country_field: str
fasta_id_field: str
insdc_segment_specific_fields: list[str]  # Fields that may differ per segment; suffixed with _{segment} when grouped
nucleotide_sequences: list[str]
segmented: bool


# submissionId is actually NCBI accession
INTRINSICALLY_SEGMENT_SPECIFIC_FIELDS: Final = {"segment", "submissionId"}


def sort_authors(authors: str) -> str:
"""Sort authors alphabetically"""
return "; ".join(sorted([author.strip() for author in authors.split(";")]))


def group_records(record_list, output_metadata, fasta_id_map, config, different_values_log=None):
if different_values_log is None:
different_values_log = {}
# Assert that all records are from different segments
for segment in config.nucleotide_sequences:
if len([record for record in record_list if record["metadata"]["segment"] == segment]) > 1:
logger.error(
"Cannot group multiple records from the same segment: "
+ ", ".join([record["id"] for record in record_list])
)
# Write the offending records to a file for later inspection
orjsonl.append("results/errors.ndjson", record_list)
return
segment_map = {record["metadata"]["segment"]: record["metadata"] for record in record_list}

all_fields = sorted(record_list[0]["metadata"].keys())

# Metadata fields that may vary between segments without indicating different assemblies
insdc_segment_specific_fields = set(config.insdc_segment_specific_fields)
insdc_segment_specific_fields.add("hash")

# Fields that in principle should be identical for all segments of the same assembly
shared_fields = sorted(
set(all_fields) - insdc_segment_specific_fields - INTRINSICALLY_SEGMENT_SPECIFIC_FIELDS
)

grouped_metadata = {}
for key in shared_fields:
if key in {"authors", "authorAffiliations"}:
values = [sort_authors(d["metadata"][key]) for d in record_list]
else:
values = [d["metadata"][key] for d in record_list]
if len(set(values)) > 1:
different_values_log[key] = different_values_log.get(key, 0) + 1
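# Conflict resolution: if the values only differ by one being empty, keep the
# non-empty value; for author fields keep the first record's (order-insensitive
# after sorting); otherwise log a warning and keep the first record's value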
if len(set(values)) == 2 and "" in set(values):
grouped_metadata[key] = next(iter(set(values) - {""}))
continue
if key in {"authors", "authorAffiliations"}:
grouped_metadata[key] = values[0]
continue
orjsonl.append(
"results/warnings.ndjson",
{
"accessions": [record["id"] for record in record_list],
"field": key,
"values": values,
},
)
grouped_metadata[key] = record_list[0]["metadata"][key]
for field in insdc_segment_specific_fields:
for segment in config.nucleotide_sequences:
grouped_metadata[f"{field}_{segment}"] = (
segment_map[segment][field] if segment in segment_map else ""
)

joint_key = "/".join(
[
f"{segment_map[segment]['insdcAccessionFull']}.{segment}"
for segment in config.nucleotide_sequences
if segment in segment_map
]
)
grouped_metadata["submissionId"] = joint_key

# Drop empty/None values before hashing so the hash is identical
# whether a field is absent or present but empty
filtered_record = {k: str(v) for k, v in grouped_metadata.items() if v is not None and str(v)}

grouped_metadata["hash"] = hashlib.md5(
json.dumps(filtered_record, sort_keys=True).encode(), usedforsecurity=False
).hexdigest()

for segment in segment_map:
accession = segment_map[segment]["insdcAccessionFull"]
fasta_id_map[accession] = f"{joint_key}_{segment}"

orjsonl.append(output_metadata, {"id": joint_key, "metadata": grouped_metadata})


@click.command()
@click.option("--config-file", required=True, type=click.Path(exists=True))
@click.option("--groups", required=True, type=click.Path(exists=True))
@click.option("--input-seq", required=True, type=click.Path(exists=True))
@click.option("--input-metadata", required=True, type=click.Path(exists=True))
@click.option("--output-seq", required=True, type=click.Path())
@click.option("--output-metadata", required=True, type=click.Path())
@click.option("--output-ungrouped-seq", required=True, type=click.Path())
@click.option("--output-ungrouped-metadata", required=True, type=click.Path())
@click.option(
"--log-level",
default="INFO",
type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]),
)
def main(
config_file: str,
groups: str,
input_seq: str,
input_metadata: str,
output_seq: str,
output_metadata: str,
output_ungrouped_seq: str,
output_ungrouped_metadata: str,
log_level: str,
) -> None:
logger.setLevel(log_level)
logging.getLogger("requests").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)

full_config = yaml.safe_load(pathlib.Path(config_file).read_text(encoding="utf-8"))
relevant_config = {key: full_config[key] for key in Config.__annotations__}
config = Config(**relevant_config)
logger.info(config)

if not config.segmented:
raise ValueError({"ERROR: You are running a function that requires segmented data"})

logger.info(f"Reading metadata from {input_metadata}")

source_of_truth = json.loads(pathlib.Path(groups).read_text(encoding="utf-8"))
logger.info(f"Found {len(source_of_truth)} source-of-truth groups")
accession_to_group = {}
for group, metadata in source_of_truth.items():
for accession in metadata:
accession_to_group[accession] = group

found_groups = {group: [] for group in source_of_truth}
# Map from original accession to the new concatenated accession
type Accession = str
type SubmissionId = str
fasta_id_map: dict[Accession, SubmissionId] = {}
ungrouped_accessions = set()
different_values_log = {}
count_total = 0
count_ungrouped = 0
for record in orjsonl.stream(input_metadata):
count_total += 1
metadata = record["metadata"]
if metadata["insdcAccessionFull"] not in accession_to_group:
count_ungrouped += 1
orjsonl.append(
output_ungrouped_metadata, {"id": record["id"], "metadata": record["metadata"]}
)
ungrouped_accessions.add(record["id"])
continue
group = accession_to_group[metadata["insdcAccessionFull"]]
found_groups[group].append(record)
if len(found_groups[group]) == len(set(source_of_truth[group])):
group_records(
found_groups[group], output_metadata, fasta_id_map, config, different_values_log
)
del found_groups[group]

logger.info(f"Found {count_total} records")
logger.info(f"Unable to group {count_ungrouped} records")

# Handle remaining groups that did not receive all their segments
count_unfilled_groups = 0
count_empty_groups = 0
for name, records in found_groups.items():
count_unfilled_groups += 1
logger.debug(
f"{name}: Missing record {set(source_of_truth[name]) - {record['metadata']['insdcAccessionFull'] for record in records}}"
)
if len(records) == 0:
count_empty_groups += 1
continue
group_records(records, output_metadata, fasta_id_map, config, different_values_log)
logger.info(f"Fields with conflicting values: {different_values_log}")
logger.info(f"Found {count_unfilled_groups} groups without all segments")
logger.info(f"Found {count_empty_groups} groups without any segments")

count_grouped = 0
count_ungrouped = 0
count_ignored = 0
for record in orjsonl.stream(input_seq):
accession = record["id"]
raw_sequence = record["sequence"]
if accession in ungrouped_accessions:
orjsonl.append(output_ungrouped_seq, {"id": accession, "sequence": raw_sequence})
count_ungrouped += 1
continue
if accession not in fasta_id_map:
count_ignored += 1
continue
orjsonl.append(
output_seq,
{
"id": fasta_id_map[accession],
"sequence": raw_sequence,
},
)
count_grouped += 1
logger.info(f"Wrote {count_grouped} grouped sequences")
logger.info(f"Wrote {count_ungrouped} ungrouped sequences")
logger.info(f"Ignored {count_ignored} sequences as not found in {input_seq}")


if __name__ == "__main__":
main()
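
For reference, a standalone invocation of the script with the same paths as the new group_segments rule in the Snakefile would look roughly like this:

    python scripts/deterministic_group_segments.py \
        --config-file results/config.yaml \
        --groups results/groups.json \
        --input-metadata results/metadata_post_prepare.ndjson \
        --input-seq results/sequences.ndjson \
        --output-metadata results/metadata_grouped.ndjson \
        --output-seq results/sequences_grouped.ndjson \
        --output-ungrouped-metadata results/metadata_ungrouped.ndjson \
        --output-ungrouped-seq results/sequences_ungrouped.ndjson \
        --log-level INFO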
@@ -24,7 +24,6 @@
import pathlib
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Final

import click
2 changes: 1 addition & 1 deletion ingest/tests/test_ingest.py
@@ -85,7 +85,7 @@ def test_snakemake():
copy_files(source_directory, destination_directory)
run_snakemake("extract_ncbi_dataset_sequences", touch=True) # Ignore sequences for now
run_snakemake("get_loculus_depositions", touch=True) # Do not call_loculus
run_snakemake("group_segments")
run_snakemake("heuristic_group_segments")
run_snakemake("get_previous_submissions", touch=True) # Do not call_loculus
run_snakemake("compare_hashes")
run_snakemake("prepare_files")