Skip to content

Commit

Permalink
Merge pull request #29 from HKU-BAL/add_cnv
Browse files Browse the repository at this point in the history
updated to v0.2.1
  • Loading branch information
zhengzhenxian authored Jul 5, 2024
2 parents c3638f3 + fbefe3e commit bb9ae6c
Show file tree
Hide file tree
Showing 22 changed files with 4,511 additions and 22 deletions.
8 changes: 8 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,23 @@ RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86
ENV PATH /opt/conda/envs/clairs/bin:$PATH
ENV CONDA_DEFAULT_ENV clairs

# Native libs needed by htslib-style compression and curl, plus the Python
# scientific stack (scipy/scikit-learn) used by the CNV/verdict code.
# Use the script-stable `apt-get` CLI and refresh the package index first —
# `apt install` on a layer without an index fails; drop the index afterwards
# to keep the image small.
RUN apt-get update && \
    apt-get install -y curl zlib1g-dev libbz2-dev liblzma-dev libcurl4-openssl-dev && \
    /opt/conda/bin/python3 -m pip install scipy scikit-learn && \
    rm -rf /var/lib/apt/lists/*

COPY . .

# Build the realigner/debruijn C++ helpers, set up the verdict allele
# counter, and download the model + CNV reference tarballs into the conda
# env.  NOTE(review): `source activate clairs` in its own /bin/bash -c has no
# effect on the following commands (each runs in a fresh shell) — kept for
# parity with the original image; confirm whether it can be dropped.
#
# Fix: the original ended with `parallel --citation || true \` — the trailing
# backslash glued the next line on, producing `true echo "..." > ~/.bashrc`,
# which truncated ~/.bashrc to an EMPTY file and never wrote the activate
# line.  The parentheses keep `|| true` scoped to the best-effort
# `parallel --citation` acknowledgement only.
RUN /bin/bash -c "source activate clairs" && cd /opt/bin/src/realign && \
    g++ -std=c++14 -O1 -shared -fPIC -o realigner ssw_cpp.cpp ssw.c realigner.cpp && \
    g++ -std=c++11 -shared -fPIC -o debruijn_graph -O3 debruijn_graph.cpp && \
    cd /opt/bin/src/verdict/allele_counter && chmod +x setup.sh && /bin/bash setup.sh /opt/bin/src/verdict/allele_counter && \
    wget http://www.bio8.cs.hku.hk/clairs/models/clairs_models.tar.gz -P /opt/models && \
    mkdir -p /opt/conda/envs/clairs/bin/clairs_models && \
    tar -zxvf /opt/models/clairs_models.tar.gz -C /opt/conda/envs/clairs/bin/clairs_models && \
    rm /opt/models/clairs_models.tar.gz && \
    mkdir -p /opt/conda/envs/clairs/bin/cnv_data && \
    wget http://www.bio8.cs.hku.hk/clairs/data/reference_files.tar.gz -P /opt/cnv_data && \
    tar -zxvf /opt/cnv_data/reference_files.tar.gz -C /opt/conda/envs/clairs/bin/cnv_data && rm -rf /opt/cnv_data/reference_files.tar.gz && \
    (echo 'will cite' | parallel --citation || true) && \
    echo "source activate clairs" > ~/.bashrc

1 change: 1 addition & 0 deletions clairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
'clair3_somatic_calling',
'cal_metrics_in_af_range',
'concat_files',
'cnv_germline_tagging',
]


Expand Down
47 changes: 35 additions & 12 deletions clairs/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import Dataset, Subset

import threading
import time
Expand Down Expand Up @@ -73,7 +73,7 @@


class BinFileDataset(Dataset):
def __init__(self, file_list, file_path, chunk_size, batch_size, debug_mode=False, discard_germline = False, add_af_in_label = False, smoothing = None):
def __init__(self, file_list, file_path, chunk_size, batch_size, debug_mode=False, discard_germline = False, add_af_in_label = False, smoothing = None, pileup=True):
### Configurations
self.debug_mode = debug_mode
self.discard_germline = discard_germline
Expand All @@ -82,12 +82,14 @@ def __init__(self, file_list, file_path, chunk_size, batch_size, debug_mode=Fals
self.chunk_size = chunk_size
self.batch_size = batch_size
self.add_af_in_label = add_af_in_label
self.random_start_position = None
self.train_flag = True

### Dataset Initialization
self.table_dataset_list, self.chunk_offset = self._populate_dataset_table(file_list, file_path)
self.cum_sum = np.cumsum(self.chunk_offset)
self.total_chunks = sum(self.chunk_offset)

self.pileup = pileup
### Smoothing
self.positive = 1 - smoothing if smoothing is not None else 1
self.negative = smoothing if smoothing is not None else 0
Expand All @@ -108,8 +110,13 @@ def __len__(self):
def __getitem__(self, idx):
bin_idx, chunk_idx = self._get_file_and_chunk_index(idx)
start_idx = chunk_idx * self.chunk_size
start_idx += self.random_start_position if self.random_start_position is not None and self.train_flag else 0
end_idx = start_idx + self.chunk_size
num_rows = len(self.table_dataset_list[bin_idx].root.input_matrix)
if end_idx > num_rows:
end_idx = num_rows
if start_idx >= num_rows:
start_idx = num_rows - 1
assert end_idx <= num_rows, f"Index out of range: {end_idx} > {num_rows}"
current_tensor = self.table_dataset_list[bin_idx].root.input_matrix[start_idx:end_idx]
current_label = self.table_dataset_list[bin_idx].root.label[start_idx:end_idx]
Expand All @@ -124,7 +131,10 @@ def __getitem__(self, idx):
current_label = [[self.negative, self.negative, self.positive] if np.argmax(item[:3]) == 2 else \
[self.negative, self.positive, self.negative] if np.argmax(item[:3]) == 1 else \
[self.positive, self.negative, self.negative]for item in current_label]
current_label = np.array(current_label)

if not self.pileup:
current_tensor = np.transpose(current_tensor, (0, 3, 1, 2))
current_label = np.array(current_label,dtype=np.float32)
if self.debug_mode:
position_info = self.table_dataset_list[bin_idx].root.position[start_idx:end_idx]
normal_info = self.table_dataset_list[bin_idx].root.normal_alt_info[start_idx:end_idx]
Expand Down Expand Up @@ -353,25 +363,31 @@ def train_model_torch_dataset(args):
if validation_fn:
val_list = os.listdir(validation_fn)
logging.info("[INFO] total {} validation bin files: {}".format(len(val_list), ','.join(val_list)))
train_dataset = BinFileDataset(bin_list, args.bin_fn, chunk_size, batch_size, debug_mode = False, discard_germline = discard_germline, \
add_af_in_label = param.add_af_in_label, smoothing=smoothing)
train_dataset = BinFileDataset(bin_list, args.bin_fn, chunk_size, batch_size, debug_mode=False, discard_germline = discard_germline, \
add_af_in_label = param.add_af_in_label, smoothing=smoothing, pileup=args.pileup)
train_chunk_num = len(train_dataset)

val_dataset = BinFileDataset(val_list, validation_fn, chunk_size, batch_size, debug_mode=debug_mode, discard_germline=discard_germline, \
add_af_in_label=False, smoothing=smoothing)
add_af_in_label=False, smoothing=smoothing, pileup=args.pileup)
validate_chunk_num = len(val_dataset)
total_chunks = train_chunk_num + validate_chunk_num
else:
total_dataset = BinFileDataset(bin_list, args.bin_fn, chunk_size, batch_size, debug_mode= debug_mode, discard_germline=discard_germline, \
add_af_in_label = param.add_af_in_label, smoothing = smoothing)
total_dataset = BinFileDataset(bin_list, args.bin_fn, chunk_size, batch_size, debug_mode=debug_mode, discard_germline=discard_germline, \
add_af_in_label = param.add_af_in_label, smoothing=smoothing, pileup=args.pileup)
total_chunks = len(total_dataset)
training_dataset_percentage = param.trainingDatasetPercentage if add_validation_dataset else None
if add_validation_dataset:
total_batches = total_chunks // chunks_per_batch
validate_chunk_num = int(
max(1., np.floor(total_batches * (1 - training_dataset_percentage))) * chunks_per_batch)
train_chunk_num = int(total_chunks - validate_chunk_num)
train_dataset, val_dataset = random_split(total_dataset, [train_chunk_num, validate_chunk_num])

train_indices = list(range(train_chunk_num))
val_indices = list(range(train_chunk_num, train_chunk_num + validate_chunk_num))

train_dataset = Subset(total_dataset, train_indices)
val_dataset = Subset(total_dataset, val_indices)

#set the training dataset to:no debug mode
train_dataset.dataset.debug_mode = False
val_dataset.dataset.add_af_in_label = False
Expand Down Expand Up @@ -438,6 +454,11 @@ def train_model_torch_dataset(args):
for epoch in range(1, max_epoch + 1):
epoch_loss = 0
fp, tp, fn = 0, 0, 0

#set a random start position for each epoch training
np.random.seed(epoch)
train_dataset.dataset.random_start_position = np.random.randint(0, chunk_size)
train_dataset.dataset.train_flag = True
t = tqdm(enumerate(train_dataloader), total=train_steps, position=0, leave=True)
v = tqdm(enumerate(validate_dataloader), total=validate_steps, position=0,
leave=True) if not debug_mode else enumerate(validate_dataloader)
Expand Down Expand Up @@ -498,6 +519,8 @@ def train_model_torch_dataset(args):
val_fp, val_tp, val_fn = 0, 0, 0
val_epoch_loss = 0
model.eval()
val_dataset.dataset.train_flag = False
val_dataset.dataset.random_start_position = None
for batch_idx, batch_tuple in v:
data, label, position_info, normal_info, tumor_info = None, None, None, None, None
if not debug_mode:
Expand Down Expand Up @@ -956,10 +979,10 @@ def main():
parser.add_argument('--exclude_training_samples', type=str, default=None,
help="Define training samples to be excluded")

parser.add_argument('--torch_dataset_num_workers', type=int, default=6,
parser.add_argument('--torch_dataset_num_workers', type=int, default=12,
help="Threads for torch dataset to preload datasets")

parser.add_argument('--torch_dataset_prefetch_factor', type=int, default=10,
parser.add_argument('--torch_dataset_prefetch_factor', type=int, default=12,
help="Prefetch factor for torch dataset to preload datasets")

# mutually-incompatible validation options
Expand Down
63 changes: 59 additions & 4 deletions run_clairs
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,22 @@ def check_args(args):
if args.use_longphase_for_intermediate_haplotagging is None:
args.use_longphase_for_intermediate_haplotagging = True

if args.enable_cnv_germline_tagging:
    # CNV germline tagging (Verdict) ships reference files for GRCh38 only.
    logging(log_warning(
        "[WARNING] The --enable_cnv_germline_tagging option currently only works for GRCh38 reference genome!"))
    # Fall back to the locations populated by the Docker/conda install when
    # the user did not supply the paths explicitly.
    if args.cnv_resource_dir is None:
        args.cnv_resource_dir = os.path.join(args.conda_prefix, 'bin', 'cnv_data', 'reference_files')
    if args.allele_counter_dir is None:
        args.allele_counter_dir = os.path.join(file_directory, 'src', 'verdict', 'allele_counter')
    # Disable the feature up front (rather than failing mid-pipeline) when a
    # required dependency is missing.
    # Fix: added the missing space after "{}" in the warning message
    # ("{}is not found" -> "{} is not found").
    if not os.path.exists(args.allele_counter_dir):
        args.enable_cnv_germline_tagging = False
        logging(log_warning(
            "[WARNING] The allele counter {} is not found, disable the --enable_cnv_germline_tagging option!".format(args.allele_counter_dir)))
    if not os.path.exists(args.cnv_resource_dir):
        args.enable_cnv_germline_tagging = False
        logging(log_warning(
            "[WARNING] The CNV resource directory {} is not found, disable the --enable_cnv_germline_tagging option!".format(args.cnv_resource_dir)))

if args.genotyping_mode_vcf_fn is not None or args.hybrid_mode_vcf_fn is not None:
logging(log_warning("[INFO] Enable --print_ref_calls and --print_germline_calls options in genotyping mode!"))
args.print_ref_calls = True
Expand Down Expand Up @@ -805,10 +821,11 @@ def print_command_line(args):
cmdline += '--clair3_snp_min_af {} '.format(args.clair3_snp_min_af) if args.clair3_snp_min_af is not None else ""
cmdline += '--clair3_indel_min_af {} '.format(args.clair3_indel_min_af) if args.clair3_indel_min_af is not None else ""
cmdline += '--enable_clair3_germline_output ' if args.enable_clair3_germline_output else ""
cmdline += '--use_heterozygous_snp_in_normal_sample_for_intermediate_phasing {}'.format(args.use_heterozygous_snp_in_normal_sample_for_intermediate_phasing) if args.use_heterozygous_snp_in_normal_sample_for_intermediate_phasing is not None else ""
cmdline += '--use_heterozygous_snp_in_tumor_sample_for_intermediate_phasing {}'.format(args.use_heterozygous_snp_in_tumor_sample_for_intermediate_phasing) if args.use_heterozygous_snp_in_tumor_sample_for_intermediate_phasing is not None else ""
cmdline += '--use_heterozygous_indel_for_intermediate_phasing {}'.format(args.use_heterozygous_indel_for_intermediate_phasing) if args.use_heterozygous_indel_for_intermediate_phasing is not None else ""
cmdline += '--use_longphase_for_intermediate_haplotagging {}'.format(args.use_longphase_for_intermediate_haplotagging) if args.use_longphase_for_intermediate_haplotagging is not None else ""
cmdline += '--use_heterozygous_snp_in_normal_sample_for_intermediate_phasing {} '.format(args.use_heterozygous_snp_in_normal_sample_for_intermediate_phasing) if args.use_heterozygous_snp_in_normal_sample_for_intermediate_phasing is not None else ""
cmdline += '--use_heterozygous_snp_in_tumor_sample_for_intermediate_phasing {} '.format(args.use_heterozygous_snp_in_tumor_sample_for_intermediate_phasing) if args.use_heterozygous_snp_in_tumor_sample_for_intermediate_phasing is not None else ""
cmdline += '--use_heterozygous_indel_for_intermediate_phasing {} '.format(args.use_heterozygous_indel_for_intermediate_phasing) if args.use_heterozygous_indel_for_intermediate_phasing is not None else ""
cmdline += '--use_longphase_for_intermediate_haplotagging {} '.format(args.use_longphase_for_intermediate_haplotagging) if args.use_longphase_for_intermediate_haplotagging is not None else ""
cmdline += '--enable_cnv_germline_tagging ' if args.enable_cnv_germline_tagging else ""
cmdline += '--conda_prefix {} '.format(args.conda_prefix) if args.conda_prefix is not None else ""
args.cmdline = cmdline
except:
Expand Down Expand Up @@ -1228,6 +1245,24 @@ def somatic_calling(args):
genotyping_command += ' 2>&1 | tee ' + args.output_dir + '/logs/6_GT.log'
commands_list += [genotyping_command]

# Optional post-processing: run the `cnv_germline_tagging` submodule on the
# already-produced output VCF so germline variants in CNV regions can be
# tagged.  The allele-counter and CNV-resource paths are resolved/validated
# (or the option auto-disabled) in check_args() before this point.
if args.enable_cnv_germline_tagging:
    echo_list.append("[INFO] Add CNV germline tagging to output VCF")
    # Build the shell command incrementally; `main_entry` dispatches to the
    # 'cnv_germline_tagging' subcommand.
    cnv_germline_tagging_command = args.python + ' ' + main_entry + ' cnv_germline_tagging'
    cnv_germline_tagging_command += ' --tumor_bam_fn ' + args.tumor_bam_fn
    cnv_germline_tagging_command += ' --normal_bam_fn ' + args.normal_bam_fn
    # Input is the merged somatic VCF; output is a separate tagged .vcf.gz
    # next to it, with intermediates under tmp/cnv_output.
    cnv_germline_tagging_command += ' --input_vcf_fn ' + args.output_dir + '/{}.vcf'.format(args.output_prefix)
    cnv_germline_tagging_command += ' --allele_counter ' + str(args.allele_counter_dir)
    cnv_germline_tagging_command += ' --cnv_resource_dir ' + str(args.cnv_resource_dir)
    cnv_germline_tagging_command += ' --output_fn ' + args.output_dir + '/{}_cnv_germline_tagged.vcf.gz'.format(args.output_prefix)
    cnv_germline_tagging_command += ' --output_dir ' + args.output_dir + '/tmp/cnv_output'
    cnv_germline_tagging_command += ' --parallel ' + args.parallel
    cnv_germline_tagging_command += ' --python ' + args.python
    cnv_germline_tagging_command += ' --contig_fn ' + args.output_dir + '/tmp/CONTIGS'
    cnv_germline_tagging_command += ' --threads ' + str(args.threads)
    # Mirror stdout/stderr of the step into its own numbered log file.
    cnv_germline_tagging_command += ' 2>&1 | tee ' + args.output_dir + '/logs/7_CGT.log'
    commands_list += [cnv_germline_tagging_command]


if args.enable_indel_calling:
##STEP 2: CREATE PAIR TENSOR
echo_list.append("[INFO] STEP 6: Indel Pileup Model Calling\n")
Expand Down Expand Up @@ -1675,6 +1710,12 @@ def somatic_parser():
help="EXPERIMENTAL: Use Clair3 default calling settings than Clair3 fast calling setting for tumor and normal germline varaint calling. The calling time would increase ~40 percent, Default: disabled"
)

# Fix: grammar in the user-facing help text ("tag the germline variant" ->
# "tag germline variants", "for sample with" -> "for samples with").
optional_params.add_argument(
    "--enable_cnv_germline_tagging",
    action='store_true',
    help="EXPERIMENTAL: Use Verdict to tag germline variants in CNV regions. We suggest using this parameter only for samples with tumor purity lower than 0.8, Default: disabled"
)

ont_params.add_argument(
"--indel_output_prefix",
type=str,
Expand Down Expand Up @@ -1902,6 +1943,20 @@ def somatic_parser():
help=SUPPRESS
)

# Hidden (help=SUPPRESS) expert option: directory holding the Verdict CNV
# reference files; when None, check_args() falls back to
# <conda_prefix>/bin/cnv_data/reference_files.
optional_params.add_argument(
    "--cnv_resource_dir",
    type=str,
    default=None,
    help=SUPPRESS
)

# Hidden (help=SUPPRESS) expert option: directory holding the allele counter;
# when None, check_args() falls back to
# <file_directory>/src/verdict/allele_counter.
optional_params.add_argument(
    "--allele_counter_dir",
    type=str,
    default=None,
    help=SUPPRESS
)

optional_params.add_argument(
"--skip_steps",
type=str,
Expand Down
11 changes: 5 additions & 6 deletions shared/param.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# parameters
caller_name = "clairs"
version = "0.2.0"
version = "0.2.1"

from itertools import accumulate

Expand Down Expand Up @@ -87,10 +87,6 @@
'ont': ont_input_shape,
'hifi': hifi_input_shape}

upper_beta = 6
lower_beta = 4
upper_beta_liqud = 5

# Training hyper parameters
use_alt_base = True
label_shape = [3]
Expand Down Expand Up @@ -121,6 +117,9 @@
TUMOR_PREFIX = 't'
variant_type = {'ref', 'homo_somatic', 'homo_germline', 'hetero_germline'}
grad_norm_clip = 1.0
upper_beta = 6
lower_beta = 4
upper_beta_liqud = 5

use_beta_subsampling = True
use_exp_subsampling = False
Expand All @@ -138,4 +137,4 @@
0.98956, 0.99107, 0.99239, 0.99355, 0.99457, 0.99545, 0.99620, 0.99685, 0.99740, 0.99786,
0.99824, 0.99855, 0.99881, 0.99902, 0.99918, 0.99931, 0.99941, 0.99949, 0.99954, 0.99958,
0.99961, 0.99963, 0.99964, 0.99964, 0.99965, 0.99965, 0.99965, 0.99965, 0.99965, 1.00000,
]
]
Loading

0 comments on commit bb9ae6c

Please sign in to comment.