Skip to content

Commit

Permalink
Port of nvscorevariants into GATK, with a basic tool frontend (#8004)
Browse files Browse the repository at this point in the history
* Port of the NVIDIA-authored nvscorevariants tool into GATK, with a basic tool frontend

* This is a direct replacement for the legacy tool CNNScoreVariants. It produces results that are almost identical to that tool, but is implemented on top of a more modern ML library, PyTorch.

* The Python code is taken from https://github.com/NVIDIA-Genomics-Research/nvscorevariants, with a few minor modifications necessary to get the tool working on newer versions of the Python libraries.

* Added pytorch-lightning to the GATK conda environment, as it's required by this tool

* Disabled jacoco in build.gradle, as it was causing strange errors related to jacoco trying to parse the new PyTorch model files in the resources directory

---------

Co-authored-by: Louis Bergelson <[email protected]>
  • Loading branch information
droazen and lbergelson authored Oct 17, 2024
1 parent a070efc commit a377b07
Show file tree
Hide file tree
Showing 25 changed files with 2,979 additions and 32 deletions.
34 changes: 23 additions & 11 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ plugins {
id "application" // provides installDist
id 'maven-publish'
id 'signing'
id "jacoco"
// id "jacoco"
id "de.undercouch.download" version "5.4.0" //used for downloading GSA lib
id "com.github.johnrengelman.shadow" version "8.1.1" //used to build the shadow and sparkJars
id "com.github.ben-manes.versions" version "0.12.0" //used for identifying dependencies that need updating
Expand Down Expand Up @@ -625,17 +625,22 @@ task bundle(type: Zip) {
}
}

jacocoTestReport {
//jacocoTestReport {
// dependsOn test
//
// group = "Reporting"
// description = "Generate Jacoco coverage reports after running tests."
// getAdditionalSourceDirs().from(sourceSets.main.allJava.srcDirs)
//
// reports {
// xml.required = true
// html.required = true
// }
//}
//}

task jacocoTestReport {
dependsOn test

group = "Reporting"
description = "Generate Jacoco coverage reports after running tests."
getAdditionalSourceDirs().from(sourceSets.main.allJava.srcDirs)

reports {
xml.required = true
html.required = true
}
}

task condaStandardEnvironmentDefinition(type: Copy) {
Expand Down Expand Up @@ -687,6 +692,13 @@ task localDevCondaEnv(type: Exec) {
commandLine "conda", "env", "create", "--yes", "-f", gatkCondaYML
}

// Updates an existing local development conda environment in place by running
// "conda env update" against the generated environment YML (contrast with
// localDevCondaEnv above, which creates the environment from scratch).
task localDevCondaUpdate(type: Exec) {
    dependsOn 'condaEnvironmentDefinition'
    // Re-run when the built Python package archive changes, since the YML references it.
    inputs.file("$buildDir/$pythonPackageArchiveName")
    workingDir "$buildDir"
    commandLine "conda", "env", "update", "-f", gatkCondaYML
}

task javadocJar(type: Jar, dependsOn: javadoc) {
archiveClassifier = 'javadoc'
from "$docBuildDir/javadoc"
Expand Down
45 changes: 24 additions & 21 deletions scripts/docker/dockertest.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ buildscript {

plugins {
id "java" // set up default java compile and test tasks
id "jacoco"
// id "jacoco"
}

repositories {
Expand Down Expand Up @@ -113,9 +113,9 @@ def getJVMArgs(runtimeAddOpens, testAddOpens) {

test {
jvmArgs = getJVMArgs(runtimeAddOpens, testAddOpens)
jacoco {
jvmArgs = getJVMArgs(runtimeAddOpens, testAddOpens)
}
// jacoco {
// jvmArgs = getJVMArgs(runtimeAddOpens, testAddOpens)
// }
}

task testOnPackagedReleaseJar(type: Test){
Expand Down Expand Up @@ -153,22 +153,25 @@ task testOnPackagedReleaseJar(type: Test){

// Task intended to collect coverage data from testOnPackagedReleaseJar executed inside the docker image
// the classpath for these tests is set at execution time for testOnPackagedReleaseJar
task jacocoTestReportOnPackagedReleaseJar(type: JacocoReport) {
String sourceFiles = "$System.env.SOURCE_DIR"
String testClassesUnpacked = "$System.env.CP_DIR"

//task jacocoTestReportOnPackagedReleaseJar(type: JacocoReport) {
// String sourceFiles = "$System.env.SOURCE_DIR"
// String testClassesUnpacked = "$System.env.CP_DIR"
//
// dependsOn testOnPackagedReleaseJar
// executionData testOnPackagedReleaseJar
// additionalSourceDirs.setFrom(sourceSets.main.allJava.srcDirs)
//
// sourceDirectories.setFrom(sourceFiles)
// classDirectories.setFrom(testClassesUnpacked)
//
// group = "Reporting"
// description = "Generate Jacoco coverage reports after running tests inside the docker image."
//
// reports {
// xml.required = true
// html.required = true
// }
//}
task jacocoTestReportOnPackagedReleaseJar {
dependsOn testOnPackagedReleaseJar
executionData testOnPackagedReleaseJar
additionalSourceDirs.setFrom(sourceSets.main.allJava.srcDirs)

sourceDirectories.setFrom(sourceFiles)
classDirectories.setFrom(testClassesUnpacked)

group = "Reporting"
description = "Generate Jacoco coverage reports after running tests inside the docker image."

reports {
xml.required = true
html.required = true
}
}
2 changes: 2 additions & 0 deletions scripts/gatkcondaenv.yml.template
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ dependencies:
- conda-forge::scipy=1.11.4
- conda-forge::h5py=3.10.0
- conda-forge::pytorch=2.1.0=*mkl*100
- conda-forge::pytorch-lightning=2.4.0 # supports Pytorch >= 2.1 and <= 2.4, used by NVScoreVariants
- conda-forge::scikit-learn=1.3.2
- conda-forge::matplotlib=3.8.2
- conda-forge::pandas=2.1.3
- conda-forge::tqdm=4.66.1
- conda-forge::dill=0.3.7 # used for pickling lambdas in TrainVariantAnnotationsModel
- conda-forge::biopython=1.84 # used by NVScoreVariants

# core R dependencies; these should only be used for plotting and do not take precedence over core python dependencies!
- r-base=4.3.1
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
package org.broadinstitute.hellbender.tools.walkers.vqsr;

import org.broadinstitute.barclay.argparser.Advanced;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.argparser.ExperimentalFeature;
import org.broadinstitute.hellbender.cmdline.CommandLineProgram;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.io.Resource;
import org.broadinstitute.hellbender.utils.python.PythonExecutorBase;
import org.broadinstitute.hellbender.utils.python.PythonScriptExecutor;
import org.broadinstitute.hellbender.utils.runtime.ProcessOutput;
import picard.cmdline.programgroups.VariantFilteringProgramGroup;

import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
* Annotate a VCF with scores from a PyTorch-based Convolutional Neural Network (CNN).
*
* It contains both a 1D model that uses only the reference sequence and variant annotations,
* and a 2D model that uses reads in addition to the reference sequence and variant annotations.
*
* The scores for each variant record will be placed in an INFO field annotation named CNN_1D
* (if using the 1D model) or CNN_2D (if using the 2D model). These scores represent the
* log odds of being a true variant versus being false under the trained convolutional neural
* network.
*
* The provided models were trained on short-read human sequencing data, and will likely not perform
* well for other kinds of sequencing data, or for non-human data. A companion training tool for
* NVScoreVariants will be released in the future to support users who need to train their own models.
*
* Example command for running with the 1D model:
*
* <pre>
* gatk NVScoreVariants \
* -V src/test/resources/large/VQSR/recalibrated_chr20_start.vcf \
* -R src/test/resources/large/human_g1k_v37.20.21.fasta \
* -O output.vcf
* </pre>
*
* Example command for running with the 2D model:
*
* <pre>
* gatk NVScoreVariants \
* -V src/test/resources/large/VQSR/recalibrated_chr20_start.vcf \
* -R src/test/resources/large/human_g1k_v37.20.21.fasta \
* --tensor-type read_tensor \
* -I src/test/resources/large/VQSR/g94982_contig_20_start_bamout.bam \
* -O output.vcf
* </pre>
*
* <b><i>The PyTorch Python code that this tool relies upon was contributed by engineers at
* <a href="https://github.com/NVIDIA-Genomics-Research">NVIDIA Genomics Research</a>.
* We would like to give particular thanks to Babak Zamirai of NVIDIA, who authored
* the tool, as well as to Ankit Sethia, Mehrzad Samadi, and George Vacek (also of NVIDIA),
* without whom this project would not have been possible.</i></b>
*/
@CommandLineProgramProperties(
    summary = "Annotate a VCF with scores from a PyTorch-based Convolutional Neural Network (CNN)",
    oneLineSummary = "Annotate a VCF with scores from a PyTorch-based Convolutional Neural Network (CNN)",
    programGroup = VariantFilteringProgramGroup.class
)
@ExperimentalFeature
public class NVScoreVariants extends CommandLineProgram {

    /** Name of the bundled Python package that implements the scoring tool. */
    public static final String NV_SCORE_VARIANTS_PACKAGE = "scorevariants";
    /** Entry-point Python script, packaged as a resource next to this class. */
    public static final String NV_SCORE_VARIANTS_SCRIPT = "nvscorevariants.py";
    public static final String NV_SCORE_VARIANTS_1D_MODEL_FILENAME = "1d_cnn_mix_train_full_bn.pt";
    public static final String NV_SCORE_VARIANTS_2D_MODEL_FILENAME = "small_2d.pt";
    /** Resource paths of the pre-trained PyTorch models shipped in the large runtime resources area. */
    public static final String NV_SCORE_VARIANTS_1D_MODEL = Resource.LARGE_RUNTIME_RESOURCES_PATH + "/nvscorevariants/" + NV_SCORE_VARIANTS_1D_MODEL_FILENAME;
    public static final String NV_SCORE_VARIANTS_2D_MODEL = Resource.LARGE_RUNTIME_RESOURCES_PATH + "/nvscorevariants/" + NV_SCORE_VARIANTS_2D_MODEL_FILENAME;

    /** Kind of tensors to generate: {@code reference} selects the 1D model, {@code read_tensor} the 2D model. */
    public enum TensorType {
        reference,
        read_tensor
    }

    @Argument(fullName = StandardArgumentDefinitions.OUTPUT_LONG_NAME, shortName = StandardArgumentDefinitions.OUTPUT_SHORT_NAME, doc = "Output VCF file")
    private File outputVCF;

    @Argument(fullName = StandardArgumentDefinitions.VARIANT_LONG_NAME, shortName = StandardArgumentDefinitions.VARIANT_SHORT_NAME, doc = "Input VCF file containing variants to score")
    private File inputVCF;

    @Argument(fullName = StandardArgumentDefinitions.REFERENCE_LONG_NAME, shortName = StandardArgumentDefinitions.REFERENCE_SHORT_NAME, doc = "Reference sequence file")
    private File reference;

    @Argument(fullName = StandardArgumentDefinitions.INPUT_LONG_NAME, shortName = StandardArgumentDefinitions.INPUT_SHORT_NAME, doc = "BAM file containing reads, if using the 2D model", optional = true)
    private File bam;

    @Argument(fullName = "tensor-type", doc = "Name of the tensors to generate: reference for 1D reference tensors and read_tensor for 2D tensors.", optional = true)
    private TensorType tensorType = TensorType.reference;

    @Argument(fullName = "batch-size", doc = "Batch size", optional = true)
    private int batchSize = 64;

    @Argument(fullName = "random-seed", doc = "Seed to initialize the random number generator", optional = true)
    private int randomSeed = 724;

    @Argument(fullName = "tmp-file", doc = "The temporary VCF-like file where variants scores will be written", optional = true)
    private File tmpFile;

    @Argument(fullName = "accelerator", doc = "Type of hardware accelerator to use (auto, cpu, cuda, mps, tpu, etc)", optional = true)
    private String accelerator = "auto";

    /**
     * Verifies before doing any work that the conda environment actually contains the
     * Python package this tool shells out to.
     */
    @Override
    protected void onStartup() {
        PythonScriptExecutor.checkPythonEnvironmentForPackage(NV_SCORE_VARIANTS_PACKAGE);
    }

    /**
     * Extracts the bundled models, assembles the command line for the Python
     * nvscorevariants script, and runs it.
     *
     * @return the exit value of the Python process (0 on success)
     */
    @Override
    protected Object doWork() {
        // Fail fast on inconsistent arguments before extracting models or launching Python.
        validateTensorTypeAndBam();

        final PythonScriptExecutor pythonExecutor = new PythonScriptExecutor(PythonExecutorBase.PythonExecutableName.PYTHON3, true);
        final Resource pythonScriptResource = new Resource(NV_SCORE_VARIANTS_SCRIPT, NVScoreVariants.class);
        final File extractedModelDirectory = extractModelFilesToTempDirectory();

        if ( tmpFile == null ) {
            tmpFile = IOUtils.createTempFile("NVScoreVariants_tmp", ".txt");
        }

        final List<String> arguments = new ArrayList<>(Arrays.asList(
            "--output-file", outputVCF.getAbsolutePath(),
            "--vcf-file", inputVCF.getAbsolutePath(),
            "--ref-file", reference.getAbsolutePath(),
            "--tensor-type", tensorType.name(),
            "--batch-size", Integer.toString(batchSize),
            "--seed", Integer.toString(randomSeed),
            "--tmp-file", tmpFile.getAbsolutePath(),
            "--model-directory", extractedModelDirectory.getAbsolutePath()
        ));

        if ( accelerator != null ) {
            arguments.addAll(Arrays.asList("--accelerator", accelerator));
        }

        if ( bam != null ) {
            arguments.addAll(Arrays.asList("--input-file", bam.getAbsolutePath()));
        }

        logger.info("Running Python NVScoreVariants module with arguments: " + arguments);
        final ProcessOutput pythonOutput = pythonExecutor.executeScriptAndGetOutput(
            pythonScriptResource,
            null,
            arguments
        );

        if ( pythonOutput.getExitValue() != 0 ) {
            // Log details here; the nonzero exit value is still propagated to the caller below.
            logger.error("Error running NVScoreVariants Python command:\n" + pythonOutput.getStatusSummary(true));
        }

        return pythonOutput.getExitValue();
    }

    /**
     * Enforces that a BAM is supplied if and only if the 2D (read_tensor) model was selected.
     *
     * @throws UserException.BadInput if a BAM was given with the 1D model, or omitted with the 2D model
     */
    private void validateTensorTypeAndBam() {
        if ( tensorType == TensorType.reference && bam != null ) {
            throw new UserException.BadInput("--" + StandardArgumentDefinitions.INPUT_LONG_NAME +
                " should only be specified when running with --tensor-type " + TensorType.read_tensor.name());
        }
        else if ( tensorType == TensorType.read_tensor && bam == null ) {
            throw new UserException.BadInput("Need to specify a BAM file via --" + StandardArgumentDefinitions.INPUT_LONG_NAME +
                " when running with --tensor-type " + TensorType.read_tensor.name());
        }
    }

    /**
     * Copies the 1D and 2D model resources out of the jar into a fresh temp directory,
     * so the Python script can load them from a normal filesystem path.
     *
     * @return the temp directory containing both model files
     * @throws UserException if a model file cannot be moved into the directory
     */
    private File extractModelFilesToTempDirectory() {
        final File extracted1DModel = IOUtils.writeTempResourceFromPath(NV_SCORE_VARIANTS_1D_MODEL, null);
        final File extracted2DModel = IOUtils.writeTempResourceFromPath(NV_SCORE_VARIANTS_2D_MODEL, null);
        final File modelDirectory = IOUtils.createTempDir("NVScoreVariants_models");

        if ( ! extracted1DModel.renameTo(new File(modelDirectory, NV_SCORE_VARIANTS_1D_MODEL_FILENAME)) ) {
            throw new UserException("Error moving " + extracted1DModel.getAbsolutePath() + " to " + modelDirectory.getAbsolutePath());
        }
        if ( ! extracted2DModel.renameTo(new File(modelDirectory, NV_SCORE_VARIANTS_2D_MODEL_FILENAME)) ) {
            throw new UserException("Error moving " + extracted2DModel.getAbsolutePath() + " to " + modelDirectory.getAbsolutePath());
        }

        logger.info("Extracted models to: " + modelDirectory.getAbsolutePath());
        return modelDirectory;
    }
}
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/usr/bin/python3

from pysam import VariantFile
import re
import argparse
import sys

# Column indices of the tab-separated records in the temporary scores file
# produced by the NVScoreVariants Python tool.
CONTIG_INDEX = 0  # chromosome/contig name
POS_INDEX = 1     # 1-based variant position
REF_INDEX = 2     # reference allele
ALT_INDEX = 3     # alternate allele list, rendered like "[A, T]"
KEY_INDEX = 4     # model score column (may be absent for unscored records)

def create_output_vcf(vcf_in, scores_file, vcf_out, label):
    """Write a copy of ``vcf_in`` to ``vcf_out`` with per-variant scores added.

    Walks the input VCF and the tab-separated ``scores_file`` in lockstep:
    each score record must match the next variant's contig, position, ref
    allele, and alt alleles. Matching variants get the score attached as a
    Float INFO field named ``label``; any mismatch (or running out of scores
    early) aborts the process via sys.exit, since the two inputs are expected
    to be exactly parallel.
    """
    variant_file = VariantFile(vcf_in)
    variant_file.reset()

    # Declare the new INFO field before copying the header for the writer.
    # (Adjacent string literals: the original backslash continuation embedded
    # the next line's leading whitespace into the header description.)
    variant_file.header.info.add(id=label, number=1, type='Float',
                                 description='Log odds of being a true variant versus '
                                             'being false under the trained Convolutional Neural Network')
    header = variant_file.header.copy()
    vcfWriter = VariantFile(vcf_out, 'w', header=header)

    with open(scores_file) as scoredVariants:
        sv = next(scoredVariants, None)
        for variant in variant_file:
            if sv is None:
                # Scores ran out before the VCF did.
                sys.exit("Score file out of sync with original VCF. Score file is exhausted, but VCF has: " + str(variant))
            scoredVariant = sv.split('\t')
            if variant.contig == scoredVariant[CONTIG_INDEX] and \
               variant.pos == int(scoredVariant[POS_INDEX]) and \
               variant.ref == scoredVariant[REF_INDEX] and \
               ', '.join(variant.alts or []) == re.sub(r'[\[\]]', '', scoredVariant[ALT_INDEX]):

                # Unscored records have no score column; copy them through unannotated.
                if len(scoredVariant) > KEY_INDEX:
                    variant.info.update({label: float(scoredVariant[KEY_INDEX])})

                vcfWriter.write(variant)

                sv = next(scoredVariants, None)
            else:
                sys.exit("Score file out of sync with original VCF. Score file has: " + sv + "\nBut VCF has: " + str(variant))

    # Close the writer to flush records and finalize the output file.
    vcfWriter.close()
Loading

0 comments on commit a377b07

Please sign in to comment.