fix(ingest): switch to ndjson instead of json to stream grouped metadata
anna-parker committed Jan 30, 2025
1 parent 9aa28ab commit 72dd94a
Showing 11 changed files with 131 additions and 220 deletions.
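The core of the change: the keyed-dict JSON files (one large object mapping each fasta id to its metadata) become NDJSON files holding one {"id": ..., "metadata": {...}} record per line, so downstream scripts can stream records instead of loading the whole file into memory. A minimal sketch of the two access patterns, using the orjsonl helpers the scripts below rely on (the path and record values are illustrative):

```python
import orjsonl

# Old layout: one big JSON object keyed by fasta id, loaded in full with
# json.load() and iterated via .items().
#
# New layout: one JSON record per line (NDJSON), written incrementally ...
orjsonl.append(
    "metadata_post_prepare.ndjson",  # illustrative local path
    {"id": "KJ682796.1.L", "metadata": {"ncbiSourceDb": "GenBank", "hash": "..."}},
)

# ... and read back lazily, one record at a time.
for field in orjsonl.stream("metadata_post_prepare.ndjson"):
    fasta_id = field["id"]
    record = field["metadata"]
    print(fasta_id, record)
```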
18 changes: 9 additions & 9 deletions ingest/Snakefile
@@ -310,7 +310,7 @@ rule prepare_metadata:
sequence_hashes="results/sequence_hashes.ndjson",
config="results/config.yaml",
output:
metadata="results/metadata_post_prepare.json",
metadata="results/metadata_post_prepare.ndjson",
params:
log_level=LOG_LEVEL,
shell:
@@ -328,11 +328,11 @@ rule prepare_metadata:
rule group_segments:
input:
script="scripts/group_segments.py",
metadata="results/metadata_post_prepare.json",
metadata="results/metadata_post_prepare.ndjson",
sequences="results/sequences.ndjson",
config="results/config.yaml",
output:
metadata="results/metadata_post_group.json",
metadata="results/metadata_post_group.ndjson",
sequences="results/sequences_post_group.ndjson",
params:
log_level=LOG_LEVEL,
@@ -368,9 +368,9 @@ rule get_previous_submissions:
# By delaying the start of the script
script="scripts/call_loculus.py",
prepped_metadata=(
"results/metadata_post_group.json"
"results/metadata_post_group.ndjson"
if SEGMENTED
else "results/metadata_post_prepare.json"
else "results/metadata_post_prepare.ndjson"
),
config="results/config.yaml",
output:
@@ -395,9 +395,9 @@ rule compare_hashes:
config="results/config.yaml",
old_hashes="results/previous_submissions.json",
metadata=(
"results/metadata_post_group.json"
"results/metadata_post_group.ndjson"
if SEGMENTED
else "results/metadata_post_prepare.json"
else "results/metadata_post_prepare.ndjson"
),
output:
to_submit="results/to_submit.json",
@@ -431,9 +431,9 @@ rule prepare_files:
script="scripts/prepare_files.py",
config="results/config.yaml",
metadata=(
"results/metadata_post_group.json"
"results/metadata_post_group.ndjson"
if SEGMENTED
else "results/metadata_post_prepare.json"
else "results/metadata_post_prepare.ndjson"
),
sequences=(
"results/sequences_post_group.ndjson"
2 changes: 0 additions & 2 deletions ingest/scripts/call_loculus.py
@@ -411,8 +411,6 @@ def get_submitted(config: Config):
logger.info(f"Backend has status of: {len(statuses)} sequence entries from ingest")
logger.info(f"Ingest has submitted: {len(entries)} sequence entries to ingest")

logger.debug(entries)
logger.debug(statuses)
for entry in entries:
loculus_accession = entry["accession"]
submitter = entry["submitter"]
6 changes: 4 additions & 2 deletions ingest/scripts/compare_hashes.py
@@ -6,6 +6,7 @@
from typing import Any

import click
import orjsonl
import requests
import yaml

@@ -171,7 +172,6 @@ def main(
config.debug_hashes = True

submitted: dict = json.load(open(old_hashes, encoding="utf-8"))
new_metadata = json.load(open(metadata, encoding="utf-8"))

update_manager = SequenceUpdateManager(
submit=[],
@@ -184,7 +184,9 @@
config=config,
)

for fasta_id, record in new_metadata.items():
for field in orjsonl.stream(metadata):
fasta_id = field["id"]
record = field["metadata"]
if not config.segmented:
insdc_accession_base = record["insdcAccessionBase"]
if not insdc_accession_base:
33 changes: 17 additions & 16 deletions ingest/scripts/group_segments.py
@@ -1,6 +1,7 @@
"""Script to group segments together into sequence entries prior to submission to Loculus
Example output for a single isolate with 3 segments:
"KJ682796.1.L/KJ682809.1.M/KJ682819.1.S": {
Example ndjson output for a single isolate with 3 segments:
{"id": "KJ682796.1.L/KJ682809.1.M/KJ682819.1.S",
"metadata": {
"ncbiReleaseDate": "2014-07-06T00:00:00Z",
"ncbiSourceDb": "GenBank",
"authors": "D. Goedhals, F.J. Burt, J.T. Bester, R. Swanepoel",
@@ -15,7 +16,7 @@
"hash_S": "f716ed13dca9c8a033d46da2f3dc2ff1",
"hash": "ce7056d0bd7e3d6d3eca38f56b9d10f8",
"submissionId": "KJ682796.1.L/KJ682809.1.M/KJ682819.1.S"
},"""
}}"""

import hashlib
import json
@@ -100,9 +101,11 @@ def main(
segments = config.nucleotide_sequences
number_of_segments = len(segments)

with open(input_metadata, encoding="utf-8") as file:
segment_metadata: dict[str, dict[str, str]] = json.load(file)
number_of_segmented_records = len(segment_metadata.keys())
number_of_segmented_records = 0
segment_metadata: dict[str, dict[str, str]] = {}
for record in orjsonl.stream(input_metadata):
segment_metadata[record["id"]] = record["metadata"]
number_of_segmented_records += 1
logger.info(f"Found {number_of_segmented_records} individual segments in metadata file")

# Group segments according to isolate, collection date and isolate specific values
@@ -174,7 +177,7 @@ def main(
number_of_groups = len(grouped_accessions)
group_lower_bound = number_of_segmented_records // number_of_segments
group_upper_bound = number_of_segmented_records
logging.info(f"Total of {number_of_groups} groups left after merging")
logger.info(f"Total of {number_of_groups} groups left after merging")
if number_of_groups < group_lower_bound:
raise ValueError(
{
@@ -192,11 +195,11 @@ def main(
}
)

# Add segment specific metadata for the segments
metadata: dict[str, dict[str, str]] = {}
# Map from original accession to the new concatenated accession
fasta_id_map: dict[Accession, Accession] = {}

count = 0

for group in grouped_accessions:
# Create key by concatenating all accession numbers with their segments
# e.g. AF1234_S/AF1235_M/AF1236_L
@@ -241,12 +244,10 @@ def main(
json.dumps(filtered_record, sort_keys=True).encode(), usedforsecurity=False
).hexdigest()

metadata[joint_key] = row
orjsonl.append(output_metadata, {"id": joint_key, "metadata": row})
count += 1

Path(output_metadata).write_text(
json.dumps(metadata, indent=4, sort_keys=True), encoding="utf-8"
)
logging.info(f"Wrote grouped metadata for {len(metadata)} sequences")
logger.info(f"Wrote grouped metadata for {count} sequences")

count = 0
count_ignored = 0
@@ -265,8 +266,8 @@ def main(
},
)
count += 1
logging.info(f"Wrote {count} sequences")
logging.info(f"Ignored {count_ignored} sequences as not found in {input_seq}")
logger.info(f"Wrote {count} sequences")
logger.info(f"Ignored {count_ignored} sequences as not found in {input_seq}")


if __name__ == "__main__":
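In group_segments.py the input segments are still collected into an in-memory dict (grouping needs all segments at once), but each grouped record is now appended to the output NDJSON as soon as it is assembled instead of being accumulated and dumped with json.dumps at the end. A small sketch of the per-group write, reusing the accession and hash from the docstring example above (values illustrative):

```python
import orjsonl

joint_key = "KJ682796.1.L/KJ682809.1.M/KJ682819.1.S"
row = {
    "submissionId": joint_key,
    "hash": "ce7056d0bd7e3d6d3eca38f56b9d10f8",
}

# One line per group; nothing is buffered, and the file grows as groups are formed.
orjsonl.append("metadata_post_group.ndjson", {"id": joint_key, "metadata": row})
```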
85 changes: 51 additions & 34 deletions ingest/scripts/prepare_files.py
@@ -1,6 +1,7 @@
import csv
import json
import logging
import os
import sys
from dataclasses import dataclass
from pathlib import Path
@@ -101,65 +102,81 @@ def main(
relevant_config = {key: full_config[key] for key in Config.__annotations__}
config = Config(**relevant_config)

metadata = json.load(open(metadata_path, encoding="utf-8"))
to_submit = json.load(open(to_submit_path, encoding="utf-8"))
to_revise = json.load(open(to_revise_path, encoding="utf-8"))
to_revoke = json.load(open(to_revoke_path, encoding="utf-8"))

metadata_submit = []
metadata_revise = []
metadata_submit_prior_to_revoke = [] # Only for multi-segmented case, sequences are revoked
# due to grouping changes and the newly grouped segments must be submitted as new sequences
submit_ids = set()
revise_ids = set()
submit_prior_to_revoke_ids = set()

for fasta_id in to_submit:
metadata_submit.append(metadata[fasta_id])
submit_ids.update(ids_to_add(fasta_id, config))
def write_to_tsv_stream(data, filename, columns_list=None):
# Check if the file exists
file_exists = os.path.exists(filename)

for fasta_id, loculus_accession in to_revise.items():
revise_record = metadata[fasta_id]
revise_record["accession"] = loculus_accession
metadata_revise.append(revise_record)
revise_ids.update(ids_to_add(fasta_id, config))
with open(filename, "a", newline="", encoding="utf-8") as output_file:
keys = columns_list or data.keys()
dict_writer = csv.DictWriter(output_file, keys, delimiter="\t")

found_seq_to_revoke = False
for fasta_id in to_revoke:
metadata_submit_prior_to_revoke.append(metadata[fasta_id])
submit_prior_to_revoke_ids.update(ids_to_add(fasta_id, config))
# Write the header only if the file doesn't already exist
if not file_exists:
dict_writer.writeheader()

if found_seq_to_revoke:
revocation_notification(config, to_revoke)
dict_writer.writerow(data)

def write_to_tsv(data, filename):
if not data:
Path(filename).touch()
return
keys = data[0].keys()
with open(filename, "w", newline="", encoding="utf-8") as output_file:
dict_writer = csv.DictWriter(output_file, keys, delimiter="\t")
dict_writer.writeheader()
dict_writer.writerows(data)
columns_list = None
for field in orjsonl.stream(metadata_path):
fasta_id = field["id"]
record = field["metadata"]
if not columns_list:
columns_list = record.keys()

if fasta_id in to_submit:
write_to_tsv_stream(record, metadata_submit_path, columns_list)
submit_ids.update(ids_to_add(fasta_id, config))
continue

if fasta_id in to_revise:
record["accession"] = to_revise[fasta_id]
write_to_tsv_stream(record, metadata_revise_path, [*columns_list, "accession"])
revise_ids.update(ids_to_add(fasta_id, config))
continue

found_seq_to_revoke = False
if fasta_id in to_revoke:
submit_prior_to_revoke_ids.update(ids_to_add(fasta_id, config))
write_to_tsv_stream(record, metadata_submit_prior_to_revoke_path, columns_list)
found_seq_to_revoke = True

write_to_tsv(metadata_submit, metadata_submit_path)
write_to_tsv(metadata_revise, metadata_revise_path)
write_to_tsv(metadata_submit_prior_to_revoke, metadata_submit_prior_to_revoke_path)
if found_seq_to_revoke:
revocation_notification(config, to_revoke)

def stream_filter_to_fasta(input, output, keep):
def stream_filter_to_fasta(input, output, output_metadata, keep):
if len(keep) == 0:
Path(output).touch()
Path(output_metadata).touch()
return
with open(output, "w", encoding="utf-8") as output_file:
for record in orjsonl.stream(input):
if record["id"] in keep:
output_file.write(f">{record['id']}\n{record['sequence']}\n")

stream_filter_to_fasta(input=sequences_path, output=sequences_submit_path, keep=submit_ids)
stream_filter_to_fasta(input=sequences_path, output=sequences_revise_path, keep=revise_ids)
stream_filter_to_fasta(
input=sequences_path,
output=sequences_submit_path,
output_metadata=metadata_submit_path,
keep=submit_ids,
)
stream_filter_to_fasta(
input=sequences_path,
output=sequences_revise_path,
output_metadata=metadata_revise_path,
keep=revise_ids,
)
stream_filter_to_fasta(
input=sequences_path,
output=sequences_submit_prior_to_revoke_path,
output_metadata=metadata_submit_prior_to_revoke_path,
keep=submit_prior_to_revoke_ids,
)

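prepare_files.py now also streams its TSV outputs: each metadata row is appended as it arrives, and the header is written only when the file does not yet exist. A self-contained sketch of that append-with-header pattern, mirroring write_to_tsv_stream from the diff (file name and rows are illustrative):

```python
import csv
import os


def write_to_tsv_stream(data: dict, filename: str, columns_list=None) -> None:
    """Append one row to a TSV, emitting the header only if the file is new."""
    file_exists = os.path.exists(filename)
    with open(filename, "a", newline="", encoding="utf-8") as output_file:
        keys = columns_list or data.keys()
        dict_writer = csv.DictWriter(output_file, keys, delimiter="\t")
        if not file_exists:
            dict_writer.writeheader()
        dict_writer.writerow(data)


# Illustrative usage: rows arrive one at a time from the metadata stream.
write_to_tsv_stream({"accession": "KJ682796", "segment": "L"}, "metadata_submit.tsv")
write_to_tsv_stream({"accession": "KJ682809", "segment": "M"}, "metadata_submit.tsv")
```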
7 changes: 2 additions & 5 deletions ingest/scripts/prepare_metadata.py
@@ -10,7 +10,6 @@
import json
import logging
from dataclasses import dataclass
from pathlib import Path

import click
import orjsonl
@@ -143,11 +142,9 @@ def main(

record["hash"] = hashlib.md5(prehash.encode(), usedforsecurity=False).hexdigest()

meta_dict = {rec[fasta_id_field]: rec for rec in metadata}
orjsonl.append(output, {"id": record[fasta_id_field], "metadata": record})

Path(output).write_text(json.dumps(meta_dict, indent=4, sort_keys=True), encoding="utf-8")

logging.info(f"Saved metadata for {len(metadata)} sequences")
logger.info(f"Saved metadata for {len(metadata)} sequences")


if __name__ == "__main__":
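One consequence of switching prepare_metadata.py and group_segments.py to orjsonl.append: the output is built by appending, so each run needs to start from a fresh file. Inside the workflow this is presumably handled by Snakemake removing declared outputs before a rule runs (an assumption, not something this commit shows); when invoking a script by hand it is worth clearing the file first, e.g.:

```python
from pathlib import Path

import orjsonl

output = "metadata_post_prepare.ndjson"  # illustrative path
# Remove any leftover file from a previous manual run, since append() never truncates.
Path(output).unlink(missing_ok=True)

for record in (
    {"id": "KJ682796.1.L", "metadata": {"segment": "L"}},
    {"id": "KJ682809.1.M", "metadata": {"segment": "M"}},
):
    orjsonl.append(output, record)
```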