Skip to content

Commit

Permalink
Do not load the models with global variables
Browse files Browse the repository at this point in the history
  • Loading branch information
PedroBarbosa committed Jun 12, 2024
1 parent f8203a0 commit 8205fbb
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 17 deletions.
6 changes: 2 additions & 4 deletions dress/datasetevaluation/representation/motifs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,7 @@ def _remove_self_contained(gr: pr.PyRanges, scan_method: str) -> pr.PyRanges:
df = pd.merge(df, contained_same_rbp, how="left", on=to_drop_cols).drop(
columns=to_clean_cols
)
df.fillna({'Has_self_submotif': False}, inplace=True)
df = df.infer_objects()
df['Has_self_submotif'] = df.Has_self_submotif.fillna(False).infer_objects(copy=False)

#######################
# Other RBP contained #
Expand All @@ -429,8 +428,7 @@ def _remove_self_contained(gr: pr.PyRanges, scan_method: str) -> pr.PyRanges:
df = pd.merge(df, contained_other_rbp, how="left", on=to_drop_cols).drop(
columns=to_clean_cols[:-1]
)
df.fillna({'Has_other_submotif': False}, inplace=True)
df = df.infer_objects()
df['Has_other_submotif'] = df.Has_other_submotif.fillna(False).infer_objects(copy=False)
# logger.debug(".. {} hits flagged ..".format(contained_other_rbp.shape[0]))

return pr.PyRanges(df)
Expand Down
33 changes: 20 additions & 13 deletions dress/datasetgeneration/black_box/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from keras.utils import pad_sequences # noqa: E402
from spliceai.utils import one_hot_encode # noqa: E402
from .singleton_model import ( # noqa: E402
batch_function_spliceAI, # noqa: E402
batch_function_spliceai, # noqa: E402
batch_function_pangolin, # noqa: E402
) # noqa: E402

Expand Down Expand Up @@ -78,6 +78,7 @@ def __init__(
):
"""SpliceAI model class"""
super().__init__(context, batch_size, scoring_metric)
self._init_model()

def run(
self, seqs: List[str], original_seq: bool = False
Expand All @@ -95,14 +96,11 @@ def run(
preds = []
batches, seq_lengths = self.data_preparation(seqs, "SpliceAI")
max_len = max(seq_lengths)
global predict_batch_spliceai
if predict_batch_spliceai is None:
predict_batch_spliceai = batch_function_spliceAI()

for _i, batch in enumerate(tqdm(batches)):
n_seqs = self.batch_size * _i
batch_tf = tf.convert_to_tensor(batch, dtype=tf.int32)
raw_preds = predict_batch_spliceai(batch_tf)
raw_preds = self.predict_batch_spliceai(batch_tf)
batch_preds = [
x.numpy()[max_len - seq_lengths[i + n_seqs] :]
for i, x in enumerate(raw_preds)
Expand Down Expand Up @@ -143,6 +141,9 @@ def get_exon_score(

return out

# Create the SpliceAI batch-prediction callable and store it on this instance.
# Invoked from __init__, so run() uses self.predict_batch_spliceai instead of
# the module-level global that was previously filled in lazily inside run().
def _init_model(self):
self.predict_batch_spliceai = batch_function_spliceai()


class Pangolin(DeepLearningModel):
def __init__(
Expand All @@ -162,6 +163,11 @@ def __init__(
self.model_nums = [1, 3, 5, 7]
elif mode == "ss_probability":
self.model_nums = [0, 2, 4, 6]
else:
logger.error(
"Invalid Pangolin mode. Please choose from ss_usage, ss_probability"
)
exit(1)

if tissue:
t_map_idx = {"heart": 0, "liver": 1, "brain": 2, "testis": 3}
Expand All @@ -176,6 +182,7 @@ def __init__(
"Invalid tissue type. Please choose from heart, liver, brain, testis"
)
exit(1)
self._init_model()

def run(
self, seqs: List[str], original_seq: bool = False
Expand All @@ -192,18 +199,15 @@ def run(
preds = []
batches, seq_lengths = self.data_preparation(seqs, "Pangolin")
max_len = max(seq_lengths)
global predict_batch_pangolin
if predict_batch_pangolin is None:
predict_batch_pangolin = batch_function_pangolin(self.model_nums)

for _i, seq in enumerate(tqdm(batches)):
for _i, seqs in enumerate(tqdm(batches)):
n_seqs = self.batch_size * _i
seq = seq.transpose(0, 2, 1)
seq = torch.from_numpy(seq).float()
seqs = seqs.transpose(0, 2, 1)
batch = torch.from_numpy(seqs).float()
if torch.cuda.is_available():
seq = seq.to(torch.device("cuda"))
batch = batch.to(torch.device("cuda"))

raw_preds = predict_batch_pangolin(seq)
raw_preds = self.predict_batch_pangolin(batch)
batch_preds = [
x[max_len - seq_lengths[i + n_seqs] :] for i, x in enumerate(raw_preds)
]
Expand Down Expand Up @@ -242,3 +246,6 @@ def get_exon_score(

out[seq_id] = self._apply_metric([acceptor, donor])
return out

# Create the Pangolin batch-prediction callable for the model indices in
# self.model_nums and store it on this instance. Invoked at the end of
# __init__, so run() uses self.predict_batch_pangolin instead of the
# module-level global that was previously filled in lazily inside run().
def _init_model(self):
self.predict_batch_pangolin = batch_function_pangolin(self.model_nums)

0 comments on commit 8205fbb

Please sign in to comment.