diff --git a/ingest/Snakefile b/ingest/Snakefile
index 3929a5ed4..417fc86fa 100644
--- a/ingest/Snakefile
+++ b/ingest/Snakefile
@@ -27,11 +27,13 @@
 FILTER_FASTA_HEADERS = config.get("filter_fasta_headers", None)
 APPROVE_TIMEOUT_MIN = config.get("approve_timeout_min")  # time in minutes
 CHECK_ENA_DEPOSITION = config.get("check_ena_deposition", False)
 ALIGN = True
+GROUPS_JSON = None
 
 dataset_server_map = {}
 dataset_name_map = {}
 if SEGMENTED:
+    GROUPS_JSON = config.get("grouping_ground_truth", None)  # JSON map from group to segments' insdcAccessionFull
     if config.get("minimizer_index") and config.get("minimizer_parser"):
         ALIGN = False
 if ALIGN:
@@ -79,7 +81,7 @@ rule fetch_ncbi_dataset_package:
         datasets download virus genome taxon {params.taxon_id} \
         --no-progressbar \
         --filename {output.dataset_package} \
-        --api-key {params.api_key} \
+        --api-key {params.api_key}\
         """
 
 
@@ -255,7 +257,7 @@ if ALIGN:
         """
 
 if not ALIGN:
-    rule download:
+    rule download_minimizer:
         output:
             results="results/minimzer.json",
         params:
@@ -275,7 +277,7 @@ if not ALIGN:
             """
             nextclade sort -m {input.minimizer} \
             -r {output.results} {input.sequences} \
-            --max-score-gap 0.3 --min-score 0.1 --min-hits 2 --all-matches
+            --max-score-gap 0.3 --min-score 0.05 --min-hits 2 --all-matches
             """
 
 rule parse_sort:
@@ -324,18 +326,60 @@ rule prepare_metadata:
         --log-level {params.log_level} \
         """
 
+if GROUPS_JSON:
+    rule download_groups:
+        output:
+            results="results/groups.json",
+        params:
+            grouping=config.get("grouping_ground_truth")
+        shell:
+            """
+            curl -L -o {output.results} {params.grouping}
+            """
+
+    rule group_segments:
+        input:
+            script="scripts/deterministic_group_segments.py",
+            metadata="results/metadata_post_prepare.ndjson",
+            sequences="results/sequences.ndjson",
+            config="results/config.yaml",
+            groups="results/groups.json",
+        output:
+            metadata="results/metadata_grouped.ndjson",
+            sequences="results/sequences_grouped.ndjson",
+            ungrouped_metadata="results/metadata_ungrouped.ndjson",
+            ungrouped_sequences="results/sequences_ungrouped.ndjson",
+        params:
+            log_level=LOG_LEVEL,
+        shell:
+            """
+            python {input.script} \
+                --config-file {input.config} \
+                --groups {input.groups} \
+                --input-metadata {input.metadata} \
+                --input-seq {input.sequences} \
+                --output-metadata {output.metadata} \
+                --output-ungrouped-metadata {output.ungrouped_metadata} \
+                --output-ungrouped-seq {output.ungrouped_sequences} \
+                --output-seq {output.sequences} \
+                --log-level {params.log_level} \
+            """
+
 
-rule group_segments:
+rule heuristic_group_segments:
     input:
-        script="scripts/group_segments.py",
-        metadata="results/metadata_post_prepare.ndjson",
-        sequences="results/sequences.ndjson",
+        script="scripts/heuristic_group_segments.py",
+        metadata="results/metadata_ungrouped.ndjson" if GROUPS_JSON else "results/metadata_post_prepare.ndjson",
+        sequences="results/sequences_ungrouped.ndjson" if GROUPS_JSON else "results/sequences.ndjson",
+        metadata_grouped="results/metadata_grouped.ndjson" if GROUPS_JSON else "results/metadata_post_prepare.ndjson",
+        sequences_grouped="results/sequences_grouped.ndjson" if GROUPS_JSON else "results/sequences.ndjson",
         config="results/config.yaml",
     output:
         metadata="results/metadata_post_group.ndjson",
        sequences="results/sequences_post_group.ndjson",
     params:
         log_level=LOG_LEVEL,
+        groups_json="true" if GROUPS_JSON else "false"
     shell:
         """
         python {input.script} \
@@ -344,7 +388,11 @@ rule group_segments:
             --input-seq {input.sequences} \
             --output-metadata {output.metadata} \
             --output-seq {output.sequences} \
-            --log-level {params.log_level} \
+            --log-level {params.log_level}
+            if [ "{params.groups_json}" = "true" ]; then
+                cat {input.metadata_grouped} >> {output.metadata}
+                cat {input.sequences_grouped} >> {output.sequences}
+            fi
         """
 
 
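Reviewer note (not part of the patch): download_groups fetches grouping_ground_truth, which
deterministic_group_segments.py consumes as a JSON map from group name to the member segments'
insdcAccessionFull values. A minimal sketch of that shape and of the inversion the script
performs — the group name "isolate-1" is illustrative, the accessions come from the script's
docstring example:

    import json

    groups = json.loads('{"isolate-1": ["KJ682796.1", "KJ682809.1", "KJ682819.1"]}')
    # The script inverts this into accession -> group before streaming the metadata NDJSON:
    accession_to_group = {acc: g for g, members in groups.items() for acc in members}
    assert accession_to_group["KJ682809.1"] == "isolate-1"
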
diff --git a/ingest/scripts/deterministic_group_segments.py b/ingest/scripts/deterministic_group_segments.py
new file mode 100644
index 000000000..8ef6eb29a
--- /dev/null
+++ b/ingest/scripts/deterministic_group_segments.py
@@ -0,0 +1,259 @@
+"""Script to group segments together into sequence entries prior to submission to Loculus,
+given a JSON map of known groups.
+
+Example output for a single isolate with 3 segments:
+{"id": "KJ682796.1.L/KJ682809.1.M/KJ682819.1.S",
+"metadata": {
+    "ncbiReleaseDate": "2014-07-06T00:00:00Z",
+    "ncbiSourceDb": "GenBank",
+    "authors": "D. Goedhals, F.J. Burt, J.T. Bester, R. Swanepoel",
+    "insdcVersion_L": "1",
+    "insdcVersion_M": "1",
+    "insdcVersion_S": "1",
+    "insdcAccessionFull_L": "KJ682796.1",
+    "insdcAccessionFull_M": "KJ682809.1",
+    "insdcAccessionFull_S": "KJ682819.1",
+    "hash_L": "ddbfc33d45267e9c1a08f8f5e76d3e39",
+    "hash_M": "f64777883ba9f5293257698255767f2c",
+    "hash_S": "f716ed13dca9c8a033d46da2f3dc2ff1",
+    "hash": "ce7056d0bd7e3d6d3eca38f56b9d10f8",
+    "submissionId": "KJ682796.1.L/KJ682809.1.M/KJ682819.1.S"
+}}"""
+
+import hashlib
+import json
+import logging
+import pathlib
+from dataclasses import dataclass
+from typing import Final
+
+import click
+import orjsonl
+import yaml
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(
+    encoding="utf-8",
+    level=logging.DEBUG,
+    format="%(asctime)s %(levelname)8s (%(filename)20s:%(lineno)4d) - %(message)s ",
+    datefmt="%H:%M:%S",
+)
+
+
+@dataclass(frozen=True)
+class Config:
+    compound_country_field: str
+    fasta_id_field: str
+    insdc_segment_specific_fields: list[str]  # fields that may differ per segment; copied into <field>_<segment> columns
+    nucleotide_sequences: list[str]
+    segmented: bool
+
+
+# submissionId is actually NCBI accession
+INTRINSICALLY_SEGMENT_SPECIFIC_FIELDS: Final = {"segment", "submissionId"}
+
+
+def sort_authors(authors: str) -> str:
+    """Sort authors alphabetically"""
+    return "; ".join(sorted([author.strip() for author in authors.split(";")]))
+
+
+def group_records(record_list, output_metadata, fasta_id_map, config, different_values_log=None):
+    if different_values_log is None:  # avoid a mutable default argument
+        different_values_log = {}
+    # Assert that all records are from a different segment
+    for segment in config.nucleotide_sequences:
+        if len([record for record in record_list if record["metadata"]["segment"] == segment]) > 1:
+            logger.error(
+                "Cannot group multiple records from the same segment: "
+                + ", ".join([record["id"] for record in record_list])
+            )
+            # write the offending records to a file instead of raising
+            orjsonl.append("results/errors.ndjson", record_list)
+            return
+    segment_map = {record["metadata"]["segment"]: record["metadata"] for record in record_list}
+
+    all_fields = sorted(record_list[0]["metadata"].keys())
+
+    # Metadata fields can vary between segments w/o indicating being from different assemblies
+    insdc_segment_specific_fields = set(config.insdc_segment_specific_fields)
+    insdc_segment_specific_fields.add("hash")
+
+    # Fields that in principle should be identical for all segments of the same assembly
+    shared_fields = sorted(
+        set(all_fields) - insdc_segment_specific_fields - INTRINSICALLY_SEGMENT_SPECIFIC_FIELDS
+    )
+
+    grouped_metadata = {}
+    for key in shared_fields:
+        if key in {"authors", "authorAffiliations"}:
+            values = [sort_authors(d["metadata"][key]) for d in record_list]
+        else:
+            values = [d["metadata"][key] for d in record_list]
+        if len(set(values)) > 1:
+            different_values_log[key] = different_values_log.get(key, 0) + 1
+            if len(set(values)) == 2 and "" in set(values):
+                grouped_metadata[key] = next(iter(set(values) - {""}))
+                continue
+            if key in {"authors", "authorAffiliations"}:
+                grouped_metadata[key] = values[0]
+                continue
+            orjsonl.append(
+                "results/warnings.ndjson",
+                {
+                    "accessions": [record["id"] for record in record_list],
+                    "field": key,
+                    "values": values,
+                },
+            )
+        grouped_metadata[key] = record_list[0]["metadata"][key]
+
+    for field in insdc_segment_specific_fields:
+        for segment in config.nucleotide_sequences:
+            grouped_metadata[f"{field}_{segment}"] = (
+                segment_map[segment][field] if segment in segment_map else ""
+            )
+
+    joint_key = "/".join(
+        [
+            f"{segment_map[segment]['insdcAccessionFull']}.{segment}"
+            for segment in config.nucleotide_sequences
+            if segment in segment_map
+        ]
+    )
+    grouped_metadata["submissionId"] = joint_key
+
+    # The hash must not change when a field is merely absent versus present but empty,
+    # so drop fields that are None or "" before hashing.
+    filtered_record = {k: str(v) for k, v in grouped_metadata.items() if v is not None and str(v)}
+
+    grouped_metadata["hash"] = hashlib.md5(
+        json.dumps(filtered_record, sort_keys=True).encode(), usedforsecurity=False
+    ).hexdigest()
+
+    for segment in segment_map:
+        accession = segment_map[segment]["insdcAccessionFull"]
+        fasta_id_map[accession] = f"{joint_key}_{segment}"
+
+    orjsonl.append(output_metadata, {"id": joint_key, "metadata": grouped_metadata})
+
+
+@click.command()
+@click.option("--config-file", required=True, type=click.Path(exists=True))
+@click.option("--groups", required=True, type=click.Path(exists=True))
+@click.option("--input-seq", required=True, type=click.Path(exists=True))
+@click.option("--input-metadata", required=True, type=click.Path(exists=True))
+@click.option("--output-seq", required=True, type=click.Path())
+@click.option("--output-metadata", required=True, type=click.Path())
+@click.option("--output-ungrouped-seq", required=True, type=click.Path())
+@click.option("--output-ungrouped-metadata", required=True, type=click.Path())
+@click.option(
+    "--log-level",
+    default="INFO",
+    type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]),
+)
+def main(
+    config_file: str,
+    groups: str,
+    input_seq: str,
+    input_metadata: str,
+    output_seq: str,
+    output_metadata: str,
+    output_ungrouped_seq: str,
+    output_ungrouped_metadata: str,
+    log_level: str,
+) -> None:
+    logger.setLevel(log_level)
+    logging.getLogger("requests").setLevel(logging.WARNING)
+    logging.getLogger("urllib3").setLevel(logging.WARNING)
+
+    full_config = yaml.safe_load(pathlib.Path(config_file).read_text(encoding="utf-8"))
+    relevant_config = {key: full_config[key] for key in Config.__annotations__}
+    config = Config(**relevant_config)
+    logger.info(config)
+
+    if not config.segmented:
+        raise ValueError("ERROR: You are running a function that requires segmented data")
+
+    logger.info(f"Reading metadata from {input_metadata}")
+
+    source_of_truth = json.loads(pathlib.Path(groups).read_text(encoding="utf-8"))
+    logger.info(f"Found {len(source_of_truth)} source of truth groups")
+    accession_to_group = {}
+    for group, accessions in source_of_truth.items():
+        for accession in accessions:
+            accession_to_group[accession] = group
+
+    found_groups = {group: [] for group in source_of_truth}
+    # Map from original accession to the new concatenated accession
+    type Accession = str
+    type SubmissionId = str
+    fasta_id_map: dict[Accession, SubmissionId] = {}
+    ungrouped_accessions = set()
+    different_values_log = {}
+    count_total = 0
+    count_ungrouped = 0
+    for record in orjsonl.stream(input_metadata):
+        count_total += 1
+        metadata = record["metadata"]
+        if metadata["insdcAccessionFull"] not in accession_to_group:
+            count_ungrouped += 1
+            orjsonl.append(
+                output_ungrouped_metadata, {"id": record["id"], "metadata": record["metadata"]}
+            )
+            ungrouped_accessions.add(record["id"])
+            continue
+        group = accession_to_group[metadata["insdcAccessionFull"]]
+        found_groups[group].append(record)
+        if len(found_groups[group]) == len(set(source_of_truth[group])):
+            group_records(
+                found_groups[group], output_metadata, fasta_id_map, config, different_values_log
+            )
+            del found_groups[group]
+
+    logger.info(f"Found {count_total} records")
+    logger.info(f"Unable to group {count_ungrouped} records")
+
+    # Flush groups for which not all segments were found in the input
+    count_unfilled_groups = 0
+    count_empty_groups = 0
+    for name, records in found_groups.items():
+        count_unfilled_groups += 1
+        logger.debug(
+            f"{name}: Missing record {set(source_of_truth[name]) - {record['metadata']['insdcAccessionFull'] for record in records}}"
+        )
+        if len(records) == 0:
+            count_empty_groups += 1
+            continue
+        group_records(records, output_metadata, fasta_id_map, config, different_values_log)
+    logger.info(different_values_log)
+    logger.info(f"Found {count_unfilled_groups} groups without all segments")
+    logger.info(f"Found {count_empty_groups} groups without any segments")
+
+    count_grouped = 0
+    count_ungrouped = 0
+    count_ignored = 0
+    for record in orjsonl.stream(input_seq):
+        accession = record["id"]
+        raw_sequence = record["sequence"]
+        if accession in ungrouped_accessions:
+            orjsonl.append(output_ungrouped_seq, {"id": accession, "sequence": raw_sequence})
+            count_ungrouped += 1
+            continue
+        if accession not in fasta_id_map:
+            count_ignored += 1
+            continue
+        orjsonl.append(
+            output_seq,
+            {
+                "id": fasta_id_map[accession],
+                "sequence": raw_sequence,
+            },
+        )
+        count_grouped += 1
+    logger.info(f"Wrote {count_grouped} grouped sequences")
+    logger.info(f"Wrote {count_ungrouped} ungrouped sequences")
+    logger.info(f"Ignored {count_ignored} sequences from {input_seq} that were not mapped to any group")
+
+
+if __name__ == "__main__":
+    main()
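Reviewer note (not part of the patch): the "hash" field is meant to stay stable whether an
optional field is absent or present-but-empty, which is why group_records drops None/"" values
before hashing. A minimal sketch mirroring that filtering — the field names used here are
illustrative:

    import hashlib
    import json

    def group_hash(metadata: dict) -> str:
        # Same filtering as group_records: keep only non-None, non-empty values.
        filtered = {k: str(v) for k, v in metadata.items() if v is not None and str(v)}
        return hashlib.md5(
            json.dumps(filtered, sort_keys=True).encode(), usedforsecurity=False
        ).hexdigest()

    # An empty field and a missing field produce the same hash:
    assert group_hash({"authors": "A; B", "note": ""}) == group_hash({"authors": "A; B"})
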
diff --git a/ingest/scripts/group_segments.py b/ingest/scripts/heuristic_group_segments.py
similarity index 99%
rename from ingest/scripts/group_segments.py
rename to ingest/scripts/heuristic_group_segments.py
index ddc1499b8..b9628bdd7 100644
--- a/ingest/scripts/group_segments.py
+++ b/ingest/scripts/heuristic_group_segments.py
@@ -24,7 +24,6 @@
 import pathlib
 from collections import defaultdict
 from dataclasses import dataclass
-from pathlib import Path
 from typing import Final
 
 import click
diff --git a/ingest/tests/test_ingest.py b/ingest/tests/test_ingest.py
index fe99b82e6..5c1a6b3ad 100644
--- a/ingest/tests/test_ingest.py
+++ b/ingest/tests/test_ingest.py
@@ -85,7 +85,7 @@ def test_snakemake():
     copy_files(source_directory, destination_directory)
     run_snakemake("extract_ncbi_dataset_sequences", touch=True)  # Ignore sequences for now
     run_snakemake("get_loculus_depositions", touch=True)  # Do not call_loculus
-    run_snakemake("group_segments")
+    run_snakemake("heuristic_group_segments")
     run_snakemake("get_previous_submissions", touch=True)  # Do not call_loculus
     run_snakemake("compare_hashes")
     run_snakemake("prepare_files")
diff --git a/kubernetes/loculus/values.yaml b/kubernetes/loculus/values.yaml
index a1e16c28d..b9fad0d66 100644
--- a/kubernetes/loculus/values.yaml
+++ b/kubernetes/loculus/values.yaml
@@ -1485,6 +1485,7 @@ defaultOrganisms:
       taxon_id: 3052518
       nextclade_dataset_server: https://raw.githubusercontent.com/nextstrain/nextclade_data/cornelius-cchfv/data_output
       nextclade_dataset_name: nextstrain/cchfv/linked
+      grouping_ground_truth: https://anna-parker.github.io/influenza-a-groupings/results/cchf_groups.json
     enaDeposition:
       configFile:
         taxon_id: 3052518
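Reviewer note (not part of the patch): with grouping_ground_truth configured for CCHFV, records
listed in cchf_groups.json are grouped deterministically and everything else falls through to
heuristic_group_segments, whose output the Snakefile then concatenates with the pre-grouped files.
A minimal sketch of how deterministic_group_segments.py derives the joint submissionId and re-keys
each segment's FASTA record — accessions taken from the script's docstring example:

    segment_map = {"L": "KJ682796.1", "M": "KJ682809.1", "S": "KJ682819.1"}
    joint_key = "/".join(f"{acc}.{seg}" for seg, acc in segment_map.items())
    # Each segment's sequence is re-keyed as "<joint_key>_<segment>":
    assert f"{joint_key}_L" == "KJ682796.1.L/KJ682809.1.M/KJ682819.1.S_L"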