Code rollback to fully revert the update to the DNA encoder version and data loader.
zmgong committed Dec 17, 2024
1 parent 5468512 commit 08991d7
Showing 6 changed files with 13 additions and 173 deletions.
2 changes: 1 addition & 1 deletion bioscanclip/config/global_config.yaml
@@ -33,7 +33,7 @@ insect_data:
image_dir: ${insect_data.dir}/images
species_to_other: ${insect_data.dir}/specie_to_other_labels.json
save_ckpt: true
- barcodeBERT_checkpoint: ${project_root_path}/ckpt/BarcodeBERT/5_mer/model_41.pth
+ bioscan_bert_checkpoint: ${project_root_path}/ckpt/BarcodeBERT/5_mer/model_41.pth
current_version: 0.1
model_output_dir: ${project_root_path}/ckpt/bioscan_clip

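For orientation, a minimal sketch of how the restored key could be read with OmegaConf is shown below. The project_root_path value and the assumption that bioscan_bert_checkpoint sits at the top level of the file are illustrative; in the project itself the config is composed internally.

# Minimal sketch, assuming OmegaConf is available and bioscan_bert_checkpoint is a
# top-level key; project_root_path is a placeholder normally supplied by the project.
from omegaconf import OmegaConf

base = OmegaConf.create({"project_root_path": "/path/to/BIOSCAN-CLIP"})
cfg = OmegaConf.merge(base, OmegaConf.load("bioscanclip/config/global_config.yaml"))

# The rollback restores the original key name, so downstream code can read it again.
print(cfg.bioscan_bert_checkpoint)  # -> /path/to/BIOSCAN-CLIP/ckpt/BarcodeBERT/5_mer/model_41.pth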
24 changes: 1 addition & 23 deletions bioscanclip/model/dna_encoder.py
@@ -10,22 +10,6 @@

device = "cuda" if torch.cuda.is_available() else "cpu"

-def remove_extra_pre_fix(state_dict):
-    new_state_dict = {}
-    for key, value in state_dict.items():
-        if key.startswith("module."):
-            key = key[7:]
-        new_state_dict[key] = value
-    return new_state_dict
-
-def load_pre_trained_bioscan_bert_trained_with_5m(bioscan_bert_checkpoint, k=4):
-    ckpt = torch.load(bioscan_bert_checkpoint, map_location=device)
-    model_ckpt = remove_extra_pre_fix(ckpt["model"])
-    bert_config = BertConfig(**ckpt["bert_config"])
-    model = BertForMaskedLM(bert_config)
-    model.load_state_dict(model_ckpt, strict=False)
-    return model.to(device)


def load_pre_trained_bioscan_bert(bioscan_bert_checkpoint, k=5):
    kmer_iter = (["".join(kmer)] for kmer in product("ACGT", repeat=k))
@@ -117,13 +101,7 @@ def reset_parameters(self) -> None:
            nn.init.zeros_(w_B.weight)

    def forward(self, x: Tensor) -> Tensor:
-        # if isinstance(self.dna_encoder.lora_barcode_bert, BertForMaskedLM):
-        # sequences = x[0]
-        # att_mask = x[1]
-        # labels = x[2]
-        # return self.lora_barcode_bert(sequences, attention_mask=att_mask, labels=labels).hidden_states[-1].logits.softmax(dim=-1).mean(dim=1)
-        # else:
-        # return self.lora_barcode_bert(x).logits.softmax(dim=-1).mean(dim=1)

        return self.lora_barcode_bert(x).logits.softmax(dim=-1).mean(dim=1)


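For context, the forward pass kept above pools the masked-LM output by softmaxing the logits over the vocabulary and averaging over sequence positions, giving one fixed-length DNA feature per barcode. The standalone sketch below illustrates that pooling; the toy BertForMaskedLM configuration is purely illustrative and stands in for the BarcodeBERT checkpoint.

# Illustrative only: a tiny randomly initialised BertForMaskedLM stands in for the
# BarcodeBERT checkpoint, just to show the softmax-then-mean pooling used in forward().
import torch
from transformers import BertConfig, BertForMaskedLM

config = BertConfig(vocab_size=100, hidden_size=64, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=128)  # toy sizes, not the real ones
model = BertForMaskedLM(config).eval()

token_ids = torch.randint(0, config.vocab_size, (8, 128))  # a batch of 8 tokenised barcodes
with torch.no_grad():
    logits = model(token_ids).logits                       # shape (8, 128, vocab_size)
dna_feature = logits.softmax(dim=-1).mean(dim=1)           # shape (8, vocab_size), as in forward()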
21 changes: 3 additions & 18 deletions bioscanclip/model/simple_clip.py
@@ -3,8 +3,7 @@
import torch.nn as nn
from bioscanclip.model.mlp import MLPEncoder
from bioscanclip.model.image_encoder import LoRA_ViT_timm, LoRA_ViT_OpenCLIP
-from bioscanclip.model.dna_encoder import load_pre_trained_bioscan_bert, LoRA_barcode_bert, Freeze_DNA_Encoder, \
-    load_pre_trained_bioscan_bert_trained_with_5m
+from bioscanclip.model.dna_encoder import load_pre_trained_bioscan_bert, LoRA_barcode_bert, Freeze_DNA_Encoder
from bioscanclip.model.language_encoder import load_pre_trained_bert, LoRA_bert, LoRA_bert_OpenCLIP
from bioscanclip.util.util import add_lora_layer_to_open_clip
import numpy as np
@@ -190,22 +189,8 @@ def load_clip_model(args, device=None):
    if hasattr(args.model_config, 'dna'):
        if args.model_config.dna.input_type == "sequence":
            if dna_model == "barcode_bert" or dna_model == "lora_barcode_bert":
-
-                if hasattr(args.model_config, "barcodeBERT_ckpt_path"):
-                    barcode_BERT_ckpt = args.model_config.barcodeBERT_ckpt_path
-                    k=4
-                    pre_trained_barcode_bert = load_pre_trained_bioscan_bert_trained_with_5m(
-                        bioscan_bert_checkpoint=barcode_BERT_ckpt, k=k)
-
-
-                else:
-                    barcode_BERT_ckpt = args.barcodeBERT_checkpoint
-                    k = 5
-                    pre_trained_barcode_bert = load_pre_trained_bioscan_bert(
-                        bioscan_bert_checkpoint=barcode_BERT_ckpt, k=k)
-
-
-
+                pre_trained_barcode_bert = load_pre_trained_bioscan_bert(
+                    bioscan_bert_checkpoint=args.bioscan_bert_checkpoint)
                if disable_lora:
                    dna_encoder = LoRA_barcode_bert(model=pre_trained_barcode_bert, r=4,
                                                    num_classes=args.model_config.output_dim, lora_layer=[])
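After the rollback, the DNA branch always goes through the single load_pre_trained_bioscan_bert path. The sketch below retraces that call path using only names and arguments visible in this diff; the checkpoint path and num_classes value are placeholders.

# Sketch of the rolled-back call path; the checkpoint path and num_classes are placeholders,
# and only names/arguments visible in this diff are used.
from bioscanclip.model.dna_encoder import load_pre_trained_bioscan_bert, LoRA_barcode_bert

bioscan_bert_checkpoint = "ckpt/BarcodeBERT/5_mer/model_41.pth"  # matches global_config.yaml

pre_trained_barcode_bert = load_pre_trained_bioscan_bert(
    bioscan_bert_checkpoint=bioscan_bert_checkpoint)  # k defaults to 5 again after the rollback

# An empty lora_layer list mirrors the disable_lora branch shown above.
dna_encoder = LoRA_barcode_bert(model=pre_trained_barcode_bert, r=4,
                                num_classes=512, lora_layer=[])  # 512 stands in for args.model_config.output_dim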
