update #1

Open
wants to merge 9 commits into master
9 changes: 9 additions & 0 deletions .gitignore
@@ -30,3 +30,12 @@ openai_api_key.json
prompts_backup/
prompt_data/mol_cluster.csv
prompt_data/mole_graph_property.csv
create_dataset.py
create_pretraining_dataset.py
create_hug_repo.py
huggingface_dataset.py
download_huggingface_dataset.py
downstream_test_huggingface.py
download_huggingface_model.py
debug.py

21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 haiteng zhao

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
113 changes: 81 additions & 32 deletions README.md
@@ -1,7 +1,7 @@
# GIMLET


This is the code for paper [GIMLET: A Unified Graph-Text Model for Instruction-Based Molecule Zero-Shot Learning](https://www.biorxiv.org/content/10.1101/2023.05.30.542904).
This is the code for paper [GIMLET: A Unified Graph-Text Model for Instruction-Based Molecule Zero-Shot Learning](https://arxiv.org/pdf/2306.13089.pdf) published at NeurIPS 2023.

GIMLET is a unified transformer model for both graph and text data and is pretrained on large scale molecule tasks with instructions, towards instruction-based molecule zero-shot learning. The framework and pretraining & downstream tasks are as follows:

@@ -13,6 +13,34 @@ GIMLET is a unified transformer model for both graph and text data and is pretra
We also benchmark baselines including KVPLM, MoMu, and Galactica on our downstream tasks for instruction-based zero-shot learning.


## Updates

### 2023.9.24

Our work has been accepted at NeurIPS 2023! The camera-ready paper is available at [https://proceedings.neurips.cc/paper_files/paper/2023/file/129033c7c08be683059559e8d6bfd460-Paper-Conference.pdf](https://proceedings.neurips.cc/paper_files/paper/2023/file/129033c7c08be683059559e8d6bfd460-Paper-Conference.pdf).

### 2023.7.10

**1.** The datasets and the GIMLET model can now be downloaded directly from HuggingFace 🤗: [https://huggingface.co/datasets/haitengzhao/molecule_property_instruction](https://huggingface.co/datasets/haitengzhao/molecule_property_instruction) and [https://huggingface.co/haitengzhao/gimlet](https://huggingface.co/haitengzhao/gimlet).

The GIMLET model can be downloaded and used as follows:

```
from model import GraphT5TransformerForConditionalGeneration
model = GraphT5TransformerForConditionalGeneration.from_pretrained("haitengzhao/gimlet")
```

Our datasets can be downloaded and used as follows:

```
from datasets import load_dataset
dataset = load_dataset("haitengzhao/molecule_property_instruction")
```

We have updated the pipeline and scripts to accommodate these new loading methods. Try out the new implementation in your projects and enjoy the improved experience!
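
For instance, you can quickly inspect what was downloaded (a minimal sketch; the exact split and column names are defined by the dataset card):

```
from datasets import load_dataset

# Download the instruction dataset from the HuggingFace Hub (cached locally after the first call).
dataset = load_dataset("haitengzhao/molecule_property_instruction")

# Print the available splits and their columns before wiring the data into the pipeline.
print(dataset)
```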

**2.** A few bugs in KVPLM testing have been fixed.

## Installation
To run GIMLET, please clone the repository to your local machine and install the required dependencies using the script provided.

@@ -36,7 +64,7 @@ pip install torch_spline_conv-1.2.1-cp37-cp37m-linux_x86_64.whl

pip install torch_geometric==1.7.2

git clone https://github.com/huggingface/transformers
git clone -b v4.28.1 https://github.com/huggingface/transformers

cd transformers
pip install --editable ./
@@ -74,7 +102,19 @@ pip install openai


### Checkpoint Download
Please download pytorch_model.bin from [https://drive.google.com/file/d/1ROU4SLW2NF9EtT70JC_SHC1OZIPB90id/view?usp=sharing](https://drive.google.com/file/d/1ROU4SLW2NF9EtT70JC_SHC1OZIPB90id/view?usp=sharing) and move it to .\ckpts\gimlet. You can do this by the following scripts:

#### Method 1: HuggingFace

Our model can now be downloaded from HuggingFace 🤗 . To download the model parameters, you can simply specify **--model_name_or_path** as **haitengzhao/gimlet**. Here's an example:

```
from model import GraphT5TransformerForConditionalGeneration
model = GraphT5TransformerForConditionalGeneration.from_pretrained("haitengzhao/gimlet")
```

#### Method 2: Manual Download

You can also download pytorch_model.bin from [https://drive.google.com/file/d/1ROU4SLW2NF9EtT70JC_SHC1OZIPB90id/view?usp=sharing](https://drive.google.com/file/d/1ROU4SLW2NF9EtT70JC_SHC1OZIPB90id/view?usp=sharing) and move it to **./ckpts/gimlet**. You can do this with the following script:

```
mkdir ckpts
@@ -90,9 +130,16 @@ cd ..
cd ..
```

In this case, the **--model_name_or_path** refers to the path of the checkpoint directory, which is **ckpts/gimlet**.
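
For example, the local checkpoint can be loaded the same way as the Hub version (a minimal sketch, assuming the **ckpts/gimlet** directory contains the model config alongside pytorch_model.bin):

```
from model import GraphT5TransformerForConditionalGeneration

# Load the manually downloaded checkpoint from the local directory instead of the HuggingFace Hub.
model = GraphT5TransformerForConditionalGeneration.from_pretrained("ckpts/gimlet")
```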


### Dataset Download

#### Method 1: HuggingFace
Our datasets are available for download on HuggingFace 🤗. You can automatically download the datasets and use the HuggingFace dataset pipeline by adding the argument **--use_huggingface_pipeline**.

#### Method 2: Manual Download
Alternatively, you can run experiments from the original molecule datasets. In this pipeline, instruction text is incorporated into the molecule data during the experiments.
The MoleculeNet datasets, which comprise pcba, bace, hiv, muv, tox21, toxcast, bbbp, esol, lipo, and freesolv, are downloaded automatically on the first run. Alternatively, you can download them manually with the script below:

```
Expand All @@ -113,9 +160,12 @@ Besides MoleculeNet, we also includes CYP450 which can be downloaded from [https
The script to run one downstream task is

```
CUDA_VISIBLE_DEVICES=0 python downstream_test.py --zero_shot --transformer_backbone gimlet --model_name_or_path ckpts/gimlet --tokenizer_name t5-small --dataset bace --runseed 5 --batch_size 40 --grad_accum_step 1 --transform_in_collator
CUDA_VISIBLE_DEVICES=0 python downstream_test.py --zero_shot --transformer_backbone gimlet --model_name_or_path haitengzhao/gimlet --tokenizer_name t5-small --dataset bace --runseed 5 --batch_size 40 --grad_accum_step 1 --transform_in_collator --only_test --use_huggingface_pipeline
```

You can include the **--use_huggingface_pipeline** flag to use the HuggingFace dataset pipeline. The flag works for both GIMLET and the baseline models in zero-shot and few-shot downstream settings.


To run all the downstream tasks, use the script downstream_test.sh. The results are written to the file "./cache/testing_$modelname.csv".
```
bash downstream_test.sh $device $backbone $modelname_or_path ($few_shot_number) ($augment_type)
@@ -149,7 +199,6 @@ bash downstream_test.sh 0 kvplm_aug ckpt_KV.pt 0 rewrite
bash downstream_test.sh 0 momu_aug littlegin=graphclinit_bert=scibert_epoch=299-step=18300.pt 0 rewrite
```


## Run Few-Shot Learning

You can run few-shot learning for all the downstream tasks by specifying the few-shot number:
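
For example, a sketch of a 5-shot run (the backbone keyword and checkpoint path are illustrative; see downstream_test.sh for the exact names used by each model):

```
# Arguments follow the template above:
# bash downstream_test.sh $device $backbone $modelname_or_path ($few_shot_number) ($augment_type)
bash downstream_test.sh 0 gimlet haitengzhao/gimlet 5
```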
@@ -170,9 +219,25 @@ bash downstream_test.sh 0 momu_fewshot littlegin=graphclinit_bert=scibert_epoch=

## Run the Pretraining

### Pretraining Data
### Run the Pretraining

To reproduce the pretraining on Chembl and Chembl property datasets, you can run the following command:
```
CUDA_VISIBLE_DEVICES=0 python pretraining_gimlet.py --model_name_or_path t5-small --tokenizer_name t5-small --transformer_backbone gimlet --do_train --train_file haitengzhao/molecule_property_instruction --transform_in_collator --per_device_train_batch_size 64 --gradient_accumulation_steps 1 --per_device_eval_batch_size 200 --line_by_line --loss_reduction_method sentence --save_steps 10000 --output_dir ckpts/gimlet_new
```

You can validate the pretrained model on the held-out Chembl split (Chembl Zero Shot):

```
CUDA_VISIBLE_DEVICES=0 python pretraining_gimlet.py --model_name_or_path ckpts/gimlet_new --tokenizer_name t5-small --transformer_backbone gimlet --do_eval --validation_file haitengzhao/molecule_property_instruction --transform_in_collator --per_device_train_batch_size 64 --gradient_accumulation_steps 1 --per_device_eval_batch_size 200 --line_by_line --loss_reduction_method sentence --save_steps 10000 --output_dir ckpts/gimlet_new
```

You can run your own pretraining by specifying --train_file as your pretraining file, or plug your own model into the pipeline.
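
For example, a sketch of pretraining on a locally generated instruction file (assuming you have produced pretrain_datasets/merge_split0.csv via the data-generation steps below; any CSV in the same format should work):

```
CUDA_VISIBLE_DEVICES=0 python pretraining_gimlet.py --model_name_or_path t5-small --tokenizer_name t5-small --transformer_backbone gimlet --do_train --train_file pretrain_datasets/merge_split0.csv --transform_in_collator --per_device_train_batch_size 64 --gradient_accumulation_steps 1 --per_device_eval_batch_size 200 --line_by_line --loss_reduction_method sentence --save_steps 10000 --output_dir ckpts/gimlet_new
```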

You can download the pretraining dataset if you want to reproduce the pretraining or train your own model. The Chembl dataset can be downloaded and processed by the following steps:

### Reproducing the Pretraining Data Generation

You can reproduce the pretraining dataset generation if you want to apply your own instruction methods. The Chembl dataset can be downloaded and processed with the following steps:
```
cd prompt_data/

@@ -208,6 +273,7 @@ cd ..
Produce the pretraining dataset by the following script:

```
cd prompts
python generate_pretrain_dataset.py --generate_assay_text --generate_mole_text --split_non_overlap --add_negation --use_augmented_prompt
```

@@ -217,35 +283,18 @@ And merge the generated dataset together:
python generate_pretrain_dataset_merge.py --merge_file_list assay_graph_text_train_non_overlap_split_0.csv assay_graph_text_detail_train_non_overlap_split_0.csv assay_graph_text_expand_train_non_overlap_split_0.csv assay_graph_text_rewrite_train_non_overlap_split_0.csv assay_graph_text_shorten_train_non_overlap_split_0.csv property_graph_text_negative05_train_non_overlap_split_0.csv property_graph_text_negative05_detail_train_non_overlap_split_0.csv property_graph_text_negative05_expand_train_non_overlap_split_0.csv property_graph_text_negative05_rewrite_train_non_overlap_split_0.csv property_graph_text_negative05_shorten_train_non_overlap_split_0.csv --merge_file_policy custom --merge_file_ratio 1.0 1.0 1.0 1.0 1.0 1.0 0.25 0.25 0.25 0.25 --final_file_name merge_split0.csv
```

### Run the Pretraining

After creating the pretraining datasets, you can reproduce the pretraining by yourself:

```
CUDA_VISIBLE_DEVICES=0 python pretraining_gimlet.py --model_name_or_path t5-small --tokenizer_name t5-small --transformer_backbone gimlet --do_train --train_file pretrain_datasets/merge_split0.csv --transform_in_collator --per_device_train_batch_size 64 --gradient_accumulation_steps 1 --per_device_eval_batch_size 200 --line_by_line --loss_reduction_method sentence --save_steps 10000 --output_dir ckpts/gimlet_new
```

You can validate the pretrained model on the held-out Chembl split (Chembl Zero Shot):

```
CUDA_VISIBLE_DEVICES=0 python pretraining_gimlet.py --model_name_or_path ckpts/gimlet_new --tokenizer_name t5-small --transformer_backbone gimlet --do_eval --validation_file pretrain_datasets/assay_graph_text_valid_non_overlap_split_0.csv --transform_in_collator --per_device_train_batch_size 64 --gradient_accumulation_steps 1 --per_device_eval_batch_size 200 --line_by_line --loss_reduction_method sentence --save_steps 10000 --output_dir ckpts/gimlet_new
```


You can run your own pretraining by specifying --train_file as your pretraining file, or plug your own model into the pipeline.

In this scenario, the pretraining data is the file "pretrain_datasets/merge_split0.csv". To validate the pretrained model, you can use the data file "pretrain_datasets/assay_graph_text_valid_non_overlap_split_0.csv". To specify these files as the training and validation data, use the arguments **--train_file** and **--validation_file** with their respective file paths.

## Citation

Please cite our paper if you find it helpful.
Please cite our paper if you find it helpful or use our datasets.
```
@article{zhao2023gimlet,
title={GIMLET: A Unified Graph-Text Model for Instruction-Based Molecule Zero-Shot Learning},
author={Zhao, Haiteng and Liu, Shengchao and Ma, Chang and Xu, Hannan and Fu, Jie and Deng, Zhi-Hong and Kong, Lingpeng and Liu, Qi},
journal={bioRxiv},
pages={2023--05},
year={2023},
publisher={Cold Spring Harbor Laboratory}
@article{zhao2024gimlet,
title={Gimlet: A unified graph-text model for instruction-based molecule zero-shot learning},
author={Zhao, Haiteng and Liu, Shengchao and Chang, Ma and Xu, Hannan and Fu, Jie and Deng, Zhihong and Kong, Lingpeng and Liu, Qi},
journal={Advances in Neural Information Processing Systems},
volume={36},
year={2024}
}
```

2 changes: 1 addition & 1 deletion basic_pipeline.py
@@ -66,7 +66,7 @@ def eval_result(model, loader,label_dict,tokenizer,task_type,transformer_backbon
batch[key] = batch[key].to(model.device)
with torch.no_grad():
labels=batch["labels"]
if labels.shape[1]>1: # Yes <s>
if labels.shape[1]>1 and not transformer_backbone in ['kvplm']: # Yes <s>
assert all((labels[:,1]==tokenizer.eos_token_id) + (labels[:,1]==id_invalid))
labels=labels[:,0].unsqueeze(1)
del batch["labels"]
3 changes: 2 additions & 1 deletion dataloaders/galatica_smiles_collator.py
@@ -67,9 +67,10 @@ def torch_call(self, examples):
def galactica_conditional_generation_tokenizer(examples,tokenizer,text_column_name,padding,max_seq_length,**kwargs):

data_new = {}
text = examples[text_column_name] if isinstance(examples[text_column_name], str) else examples[text_column_name][0]
tokenized_input = tokenizer(
# examples[text_column_name]+ ' ',
'[START_I_SMILES]' + examples['graph'] + '[END_I_SMILES]\n\n##Question: ' + examples[text_column_name] + '\n\nAnswer:',
'[START_I_SMILES]' + examples['graph'] + '[END_I_SMILES]\n\n##Question: ' + text + '\n\nAnswer:',
padding=padding,
truncation=True,
max_length=max_seq_length,
3 changes: 2 additions & 1 deletion dataloaders/gpt3_smiles_collator.py
@@ -24,8 +24,9 @@
def gpt3_conditional_generation_tokenizer(examples,tokenizer,text_column_name,padding,max_seq_length,**kwargs):

data_new = {}
text = examples[text_column_name] if isinstance(examples[text_column_name], str) else examples[text_column_name][0]
tokenized_input = tokenizer(
'Please answer questions on this molecule. The SMILES of this molecule is:' + examples['graph'] + '\n\n##Question: ' + examples[text_column_name] + '\n\nAnswer:',
'Please answer questions on this molecule. The SMILES of this molecule is:' + examples['graph'] + '\n\n##Question: ' + text + '\n\nAnswer:',
padding=padding,
truncation=True,
max_length=max_seq_length,
4 changes: 2 additions & 2 deletions dataloaders/graph_text_transform.py
@@ -71,7 +71,7 @@ def tokenize_function_gin_T5(examples,tokenizer,text_column_name,padding,max_seq
# Remove empty lines
# examples[text_column_name] = [line for line in examples[text_column_name] if len(line) > 0 and not line.isspace()]
text = tokenizer(
examples[text_column_name],
examples[text_column_name] if isinstance(examples[text_column_name],str) else examples[text_column_name][0],
padding=padding,
truncation=True,
max_length=max_seq_length,
@@ -107,7 +107,7 @@ def tokenize_function_gimlet(examples, tokenizer, text_column_name, padding, max
# Remove empty lines
# examples[text_column_name] = [line for line in examples[text_column_name] if len(line) > 0 and not line.isspace()]
text = tokenizer(
examples[text_column_name],
examples[text_column_name] if isinstance(examples[text_column_name],str) else examples[text_column_name][0], # if examples[text_column_name] is list
padding=padding,
truncation=True,
max_length=max_seq_length,
3 changes: 2 additions & 1 deletion dataloaders/kvplm_smiles_collator.py
@@ -113,9 +113,10 @@ def torch_call(self, examples):
def kvplm_conditional_generation_tokenizer(examples,tokenizer,text_column_name,padding,max_seq_length,**kwargs):

data_new = {}
text=examples[text_column_name] if isinstance(examples[text_column_name],str) else examples[text_column_name][0]
tokenized_input = tokenizer(
examples['graph'] + ' '+
examples[text_column_name]+ ' ',
text+ ' ',
padding=padding,
truncation=True,
max_length=max_seq_length,
5 changes: 3 additions & 2 deletions dataloaders/momu_collator.py
@@ -23,8 +23,9 @@
def contrastive_conditional_generation_tokenizer(examples,tokenizer,text_column_name,padding,max_seq_length,rich_features,**kwargs):
label_dict={'Yes':[1],'No':[0]}
data_new = {}
tokenized_input_pos=tokenizer(examples[text_column_name]+' '+'Yes',truncation=True,max_length=512)
tokenized_input_neg=tokenizer(examples[text_column_name]+' '+'No',truncation=True,max_length=512)
text=examples[text_column_name] if isinstance(examples[text_column_name],str) else examples[text_column_name][0]
tokenized_input_pos=tokenizer(text+' '+'Yes',truncation=True,max_length=512)
tokenized_input_neg=tokenizer(text+' '+'No',truncation=True,max_length=512)
# if not transform_in_collator:
# examples['graph'] = smiles2graph(examples['graph'])
data_new['graph']=examples['graph']