Skip to content

Commit

Permalink
add barcodeBERT pre-trained with 5M data
Browse files Browse the repository at this point in the history
  • Loading branch information
zmgong committed Dec 16, 2024
1 parent d17e7da commit 9246738
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 123 deletions.
2 changes: 1 addition & 1 deletion bioscanclip/config/global_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ insect_data:
image_dir: ${insect_data.dir}/images
species_to_other: ${insect_data.dir}/specie_to_other_labels.json
save_ckpt: true
bioscan_bert_checkpoint: ${project_root_path}/ckpt/BarcodeBERT/5_mer/model_41.pth
barcodeBERT_checkpoint: ${project_root_path}/ckpt/BarcodeBERT/5_mer/model_41.pth
current_version: 0.1
model_output_dir: ${project_root_path}/ckpt/bioscan_clip

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Model/training configuration: image, DNA and text encoders on the
# `bioscan_5m` dataset.
#
# NOTE(review): the nesting of the mapping sections below (image, dna,
# language, lr_config, loss_setup, fine_tuning_set) is reconstructed. Without
# it the keys input_type, pre_train_model, batch_size and epochs are
# duplicated at the top level and the section keys themselves load as null.
# Confirm the section boundaries against the code that reads this config.
batch_size: 500
epochs: 20
# Left empty, so this key loads as null.
labels_for_driven_positive_and_negative_pairs:
wandb_project_name: BIOSCAN-CLIP-5M
using_train_seen_for_pre_train: true
dataset: bioscan_5m

# One section per modality; load_clip_model reads `dna.input_type`.
image:
  input_type: image
  pre_train_model: vit_base_patch16_224
dna:
  input_type: sequence
  pre_train_model: barcode_bert
language:
  input_type: sequence
  pre_train_model: prajjwal1/bert-small

model_output_name: image_dna_text_4gpu
evaluation_period: 1

# BarcodeBERT checkpoint for the DNA encoder. When this key is present,
# load_clip_model uses it instead of the global checkpoint setting.
# NOTE(review): the directory here is `ckpt/barcodeBERT`, while
# global_config.yaml points at `ckpt/BarcodeBERT` - these are different
# directories on a case-sensitive filesystem; confirm this is intended.
barcodeBERT_ckpt_path: ${project_root_path}/ckpt/barcodeBERT/trained_with_5m/(BIOSCAN-5M)-BEST_k4_6_6_w1_m0_r0.pt

ckpt_path: ${project_root_path}/ckpt/bioscan_clip/new_5M_training/trained_with_5M_image_dna_text/best.pth
# Passed as `num_classes` when the DNA encoder wrapper is built.
output_dim: 768
# presumably the rendezvous port for distributed training - TODO confirm
port: 29531

disable_lora: true
lr_scheduler: one_cycle
lr_config:
  # Written with an explicit decimal point: YAML 1.1 loaders (e.g. PyYAML)
  # read `1e-6` / `5e-5` as strings rather than floats.
  lr: 1.0e-6
  max_lr: 5.0e-5

all_gather: true
loss_setup:
  gather_with_grad: true
  use_horovod: false
  local_loss: false
# NOTE(review): fix_temperature and amp are assumed to be top-level keys
# rather than members of loss_setup - verify against the training script.
fix_temperature: false
amp: true

random_seed: false

eval_skip_epoch: 10

default_seed: 42

# Separate batch size / epoch count for the supervised fine-tuning stage.
fine_tuning_set:
  batch_size: 150
  epochs: 15
  fine_tune_model_output_dir: ${model_output_dir}/${model_config.model_output_name}/supervise_fine_tune_ckpt
104 changes: 0 additions & 104 deletions bioscanclip/model/barcode_bert.py

This file was deleted.

22 changes: 4 additions & 18 deletions bioscanclip/model/simple_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,27 +259,13 @@ def load_clip_model(args, device=None):

# For DNA part
if hasattr(args.model_config, 'dna'):
# if hasattr(args.model_config.dna, 'freeze') and args.model_config.dna.freeze:
# dna_encoder = Freeze_DNA_Encoder()
# elif args.model_config.dna.input_type == "sequence":
# if dna_model == "barcode_bert" or dna_model == "lora_barcode_bert":
# pre_trained_barcode_bert = load_pre_trained_bioscan_bert(
# bioscan_bert_checkpoint=args.bioscan_bert_checkpoint)
# if disable_lora:
# dna_encoder = LoRA_barcode_bert(model=pre_trained_barcode_bert, r=4,
# num_classes=args.model_config.output_dim, lora_layer=[])
# else:
# dna_encoder = LoRA_barcode_bert(model=pre_trained_barcode_bert, r=4,
# num_classes=args.model_config.output_dim)
# else:
# dna_encoder = MLPEncoder(input_dim=args.model_config.dna.input_dim,
# hidden_dim=args.model_config.dna.hidden_dim,
# output_dim=args.model_config.output_dim)
#
if args.model_config.dna.input_type == "sequence":
if dna_model == "barcode_bert" or dna_model == "lora_barcode_bert":
# Resolve which BarcodeBERT checkpoint to load. A per-model
# `barcodeBERT_ckpt_path` takes precedence; otherwise fall back to the global
# default, whose key in global_config.yaml is `barcodeBERT_checkpoint`
# (lower-case "b" - attribute lookup on the config is case-sensitive).
# The global key is only read when no override exists, so a model config that
# supplies its own path does not depend on the global entry being present.
if hasattr(args.model_config, "barcodeBERT_ckpt_path"):
    barcode_BERT_ckpt = args.model_config.barcodeBERT_ckpt_path
else:
    barcode_BERT_ckpt = args.barcodeBERT_checkpoint
pre_trained_barcode_bert = load_pre_trained_bioscan_bert(
bioscan_bert_checkpoint=args.bioscan_bert_checkpoint)
bioscan_bert_checkpoint=barcode_BERT_ckpt)
if disable_lora:
dna_encoder = LoRA_barcode_bert(model=pre_trained_barcode_bert, r=4,
num_classes=args.model_config.output_dim, lora_layer=[])
Expand Down

0 comments on commit 9246738

Please sign in to comment.