From 7f4f53055ae82cb7bb21b64d2f400049e14cf792 Mon Sep 17 00:00:00 2001 From: zhao-ht Date: Mon, 10 Jul 2023 16:23:34 +0800 Subject: [PATCH 1/9] updata_file --- .gitignore | 6 ++++++ downstream_test.py | 6 +++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index b713501..39bc23b 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,9 @@ openai_api_key.json prompts_backup/ prompt_data/mol_cluster.csv prompt_data/mole_graph_property.csv +create_dataset.py +create_hug_repo.py +huggingface_dataset.py +download_huggingface_dataset.py +downstream_test_huggingface.py + diff --git a/downstream_test.py b/downstream_test.py index 9d7de9d..9436681 100644 --- a/downstream_test.py +++ b/downstream_test.py @@ -174,7 +174,7 @@ def train(model, loader, optimizer): -def downstream_task_by_transform(transform,model,train_loader,val_loader,test_loader,prompt=''): +def downstream_task_by_transform(model,train_loader,val_loader,test_loader,prompt=''): #reload the model parameter if args.few_shot: model = get_model(args, graph_args,tokenizer) @@ -533,7 +533,7 @@ def count_parameters(model): val_loader.dataset.transform = transform test_loader.dataset.transform = transform - downstream_task_by_transform(transform,model,train_loader,val_loader,test_loader,prompt[args.prompt_id[0]]) + downstream_task_by_transform(model,train_loader,val_loader,test_loader,prompt[args.prompt_id[0]]) elif args.prompt_policy == 'traversal': for prompt_id in args.prompt_id[str(single_split_label)]: @@ -550,7 +550,7 @@ def count_parameters(model): val_loader.dataset.transform = transform test_loader.dataset.transform = transform - downstream_task_by_transform(transform,model,train_loader,val_loader,test_loader,prompt[str(single_split_label)][prompt_id]) + downstream_task_by_transform(model,train_loader,val_loader,test_loader,prompt[str(single_split_label)][prompt_id]) else: raise ValueError('prompt_policy not implemented yet') From ac780d5fcd50ea911b9d7f1bd5041ccb8e18a4e4 Mon Sep 17 00:00:00 2001 From: zhao-ht Date: Thu, 13 Jul 2023 16:02:13 +0800 Subject: [PATCH 2/9] updata_file --- .gitignore | 3 + README.md | 90 ++- basic_pipeline.py | 2 +- dataloaders/galatica_smiles_collator.py | 3 +- dataloaders/gpt3_smiles_collator.py | 3 +- dataloaders/graph_text_transform.py | 4 +- dataloaders/kvplm_smiles_collator.py | 3 +- dataloaders/momu_collator.py | 5 +- downstream_test.py | 536 ++++++++++++------ ...MLETTransformerForConditionalGeneration.py | 10 +- model/__init__.py | 48 +- pretraining_gimlet.py | 82 +-- 12 files changed, 515 insertions(+), 274 deletions(-) diff --git a/.gitignore b/.gitignore index 39bc23b..cd2b6cb 100644 --- a/.gitignore +++ b/.gitignore @@ -31,8 +31,11 @@ prompts_backup/ prompt_data/mol_cluster.csv prompt_data/mole_graph_property.csv create_dataset.py +create_pretraining_dataset.py create_hug_repo.py huggingface_dataset.py download_huggingface_dataset.py downstream_test_huggingface.py +download_huggingface_model.py +debug.py diff --git a/README.md b/README.md index 925cda5..2223041 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,30 @@ GIMLET is a unified transformer model for both graph and text data and is pretra We also benchmark baselines including KVPLM, MoMu, and Galactica on our downstream tasks for instruction-based zero-shot learning. 
+## Updates + +### 2023.7.10 + +**1.** Now the datasets and the GIMLET model can be download directly from HuggingFace: [https://huggingface.co/datasets/haitengzhao/molecule_property_instruction](https://huggingface.co/datasets/haitengzhao/molecule_property_instruction) and [https://huggingface.co/haitengzhao/gimlet](https://huggingface.co/haitengzhao/gimlet). + +The GIMLET model can be downloaded and used as follows: + +``` +from model import GraphT5TransformerForConditionalGeneration +model = GraphT5TransformerForConditionalGeneration.from_pretrained("haitengzhao/gimlet") +``` + +Our datasets can be downloaded and used as follows: + +``` +from datasets import load_dataset +dataset = load_dataset("haitengzhao/molecule_property_instruction") +``` + +We have made updates to the pipeline and scripts to accommodate the new loading methods. Try out the new implementation in your projects and enjoy the improved experience! + +**2.** A few bugs in KVPLM testing have been fixed. + ## Installation To run GIMLET, please clone the repository to your local machine and install the required dependencies using the script provided. @@ -74,7 +98,19 @@ pip install openai ### Checkpoint Download -Please download pytorch_model.bin from [https://drive.google.com/file/d/1ROU4SLW2NF9EtT70JC_SHC1OZIPB90id/view?usp=sharing](https://drive.google.com/file/d/1ROU4SLW2NF9EtT70JC_SHC1OZIPB90id/view?usp=sharing) and move it to .\ckpts\gimlet. You can do this by the following scripts: + +#### Method 1: HuggingFace + +Our model can now be downloaded from HuggingFace. To download the model parameters, you can simply specify **--model_name_or_path** as **haitengzhao/gimlet**. Here's an example: + +``` +from model import GraphT5TransformerForConditionalGeneration +model = GraphT5TransformerForConditionalGeneration.from_pretrained("haitengzhao/gimlet") +``` + +#### Method 2: Manual Download + +You can also download pytorch_model.bin from [https://drive.google.com/file/d/1ROU4SLW2NF9EtT70JC_SHC1OZIPB90id/view?usp=sharing](https://drive.google.com/file/d/1ROU4SLW2NF9EtT70JC_SHC1OZIPB90id/view?usp=sharing) and move it to **.\ckpts\gimlet**. You can do this by the following scripts: ``` mkdir ckpts @@ -90,9 +126,16 @@ cd .. cd .. ``` +In this case, the **--model_name_or_path** refers to the path of the checkpoint directory, which is **ckpts/gimlet**. + ### Dataset Download +#### Method 1: HuggingFace +Our datasets is available for download on HuggingFace. You can automatically download the datasets and use the huggingface dataset pipeline by augment **--use_huggingface_pipeline**. + +#### Method 2: Manual Download +Alternatively, you can run experiments from the original molecule datasets. In this pipeline, we will incorporate instruction text to the molecule data during the experimentation process. The MoleculeNet datasets, which comprise pcba, bace, hiv, muv, tox21, toxcast, bbbp, esol, lipo, and freesolv, can be conveniently downloaded automatically upon the first run. 
Alternatively, you can manually download them by following the script below: ``` @@ -113,9 +156,12 @@ Besides MoleculeNet, we also includes CYP450 which can be downloaded from [https The script to run one downstream task is ``` -CUDA_VISIBLE_DEVICES=0 python downstream_test.py --zero_shot --transformer_backbone gimlet --model_name_or_path ckpts/gimlet --tokenizer_name t5-small --dataset bace --runseed 5 --batch_size 40 --grad_accum_step 1 --transform_in_collator +CUDA_VISIBLE_DEVICES=0 python downstream_test.py --zero_shot --transformer_backbone gimlet --model_name_or_path haitengzhao/gimlet --tokenizer_name t5-small --dataset bace --runseed 5 --batch_size 40 --grad_accum_step 1 --transform_in_collator --only_test --use_huggingface_pipeline ``` +You have the option to include the **--use_huggingface_pipeline** flag to utilize the HuggingFace dataset pipeline. This feature is applicable for both GIMLET and baseline models in downstream scenarios involving zero-shot and few-shot settings. + + To execute all the downstream tasks, you can utilize the script downstream_test.sh. Running this script will generate results that will be written into the file "./cache/testing_$modelname.csv". ``` bash downstream_test.sh $device $backbone $modelname_or_path ($few_shot_number) ($augment_type) @@ -149,7 +195,6 @@ bash downstream_test.sh 0 kvplm_aug ckpt_KV.pt 0 rewrite bash downstream_test.sh 0 momu_aug littlegin=graphclinit_bert=scibert_epoch=299-step=18300.pt 0 rewrite ``` - ## Run Few-Shot Learning You can run few-shot learning for all the downstream tasks by specify the few-shot number: @@ -170,9 +215,25 @@ bash downstream_test.sh 0 momu_fewshot littlegin=graphclinit_bert=scibert_epoch= ## Run the Pretraining -### Pretraining Data +### Run the Pretraining + +To reproduce the pretraining on Chembl and Chembl property datasets, you can run the following command: +``` +CUDA_VISIBLE_DEVICES=0 python pretraining_gimlet.py --model_name_or_path t5-small --tokenizer_name t5-small --transformer_backbone gimlet --do_train --train_file haitengzhao/molecule_property_instruction --transform_in_collator --per_device_train_batch_size 64 --gradient_accumulation_steps 1 --per_device_eval_batch_size 200 --line_by_line --loss_reduction_method sentence --save_steps 10000 --output_dir ckpts/gimlet_new +``` -You can download the pretraining dataset if you want to reproduce the pretraining or train your own model. The Chembl dataset can be downloaded and processed by the following steps: +You can validate the pretrained model on the splitted Chembl dataset (Chembl Zero Shot): + +``` +CUDA_VISIBLE_DEVICES=0 python pretraining_gimlet.py --model_name_or_path ckpts/gimlet_new --tokenizer_name t5-small --transformer_backbone gimlet --do_eval --validation_file haitengzhao/molecule_property_instruction --transform_in_collator --per_device_train_batch_size 64 --gradient_accumulation_steps 1 --per_device_eval_batch_size 200 --line_by_line --loss_reduction_method sentence --save_steps 10000 --output_dir ckpts/gimlet_new +``` + +You can run your own pretraining by specifying --train_file as your pretraining file, or imply your model into the pipeline. + + +### Reproducing the Pretraining Data Generation + +You can reproduce the pretraining dataset generation if you want to imply your own instruction methods. The Chembl dataset can be downloaded and processed by the following steps: ``` cd prompt_data/ @@ -208,6 +269,7 @@ cd .. 
Produce the pretraining dataset by the following script: ``` +cd prompts python generate_pretrain_dataset.py --generate_assay_text --generate_mole_text --split_non_overlap --add_negation --use_augmented_prompt ``` @@ -217,23 +279,7 @@ And merge the generated dataset together: python generate_pretrain_dataset_merge.py --merge_file_list assay_graph_text_train_non_overlap_split_0.csv assay_graph_text_detail_train_non_overlap_split_0.csv assay_graph_text_expand_train_non_overlap_split_0.csv assay_graph_text_rewrite_train_non_overlap_split_0.csv assay_graph_text_shorten_train_non_overlap_split_0.csv property_graph_text_negative05_train_non_overlap_split_0.csv property_graph_text_negative05_detail_train_non_overlap_split_0.csv property_graph_text_negative05_expand_train_non_overlap_split_0.csv property_graph_text_negative05_rewrite_train_non_overlap_split_0.csv property_graph_text_negative05_shorten_train_non_overlap_split_0.csv --merge_file_policy custom --merge_file_ratio 1.0 1.0 1.0 1.0 1.0 1.0 0.25 0.25 0.25 0.25 --final_file_name merge_split0.csv ``` -### Run the Pretraining - -After creating the pretraining datasets, you can reproduce the pretraining by yourself: - -``` -CUDA_VISIBLE_DEVICES=0 python pretraining_gimlet.py --model_name_or_path t5-small --tokenizer_name t5-small --transformer_backbone gimlet --do_train --train_file pretrain_datasets/merge_split0.csv --transform_in_collator --per_device_train_batch_size 64 --gradient_accumulation_steps 1 --per_device_eval_batch_size 200 --line_by_line --loss_reduction_method sentence --save_steps 10000 --output_dir ckpts/gimlet_new -``` - -You can validate the pretrained model on the splitted Chembl dataset (Chembl Zero Shot): - -``` -CUDA_VISIBLE_DEVICES=0 python pretraining_gimlet.py --model_name_or_path ckpts/gimlet_new --tokenizer_name t5-small --transformer_backbone gimlet --do_eval --validation_file pretrain_datasets/assay_graph_text_valid_non_overlap_split_0.csv --transform_in_collator --per_device_train_batch_size 64 --gradient_accumulation_steps 1 --per_device_eval_batch_size 200 --line_by_line --loss_reduction_method sentence --save_steps 10000 --output_dir ckpts/gimlet_new -``` - - -You can run your own pretraining by specifying --train_file as your pretraining file, or imply your model into the pipeline. - +In this scenario, the pretraining data is the file "pretrain_datasets/merge_split0.csv". To validate the pretrained model, you can use the data file "pretrain_datasets/assay_graph_text_valid_non_overlap_split_0.csv". To specify these files as the training and validation data, use the arguments **--train_file** and **--validation_file** with their respective file paths. 
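+To sanity-check the generated files before launching pretraining, you can open them with the `datasets` library (a minimal sketch for inspection only, assuming the merge step above produced both CSV files):
+
+```
+from datasets import load_dataset
+
+# Load the locally generated pretraining CSVs, i.e. the files passed to
+# --train_file and --validation_file.
+raw_datasets = load_dataset(
+    "csv",
+    data_files={
+        "train": "pretrain_datasets/merge_split0.csv",
+        "validation": "pretrain_datasets/assay_graph_text_valid_non_overlap_split_0.csv",
+    },
+)
+print(raw_datasets)              # number of rows per split
+print(raw_datasets["train"][0])  # one instruction record produced by the scripts above
+```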
## Citation diff --git a/basic_pipeline.py b/basic_pipeline.py index 77749d1..7b3fcec 100644 --- a/basic_pipeline.py +++ b/basic_pipeline.py @@ -66,7 +66,7 @@ def eval_result(model, loader,label_dict,tokenizer,task_type,transformer_backbon batch[key] = batch[key].to(model.device) with torch.no_grad(): labels=batch["labels"] - if labels.shape[1]>1: # Yes + if labels.shape[1]>1 and not transformer_backbone in ['kvplm']: # Yes assert all((labels[:,1]==tokenizer.eos_token_id) + (labels[:,1]==id_invalid)) labels=labels[:,0].unsqueeze(1) del batch["labels"] diff --git a/dataloaders/galatica_smiles_collator.py b/dataloaders/galatica_smiles_collator.py index 9656420..e6a7e91 100644 --- a/dataloaders/galatica_smiles_collator.py +++ b/dataloaders/galatica_smiles_collator.py @@ -67,9 +67,10 @@ def torch_call(self, examples): def galactica_conditional_generation_tokenizer(examples,tokenizer,text_column_name,padding,max_seq_length,**kwargs): data_new = {} + text = examples[text_column_name] if isinstance(examples[text_column_name], str) else examples[text_column_name][0] tokenized_input = tokenizer( # examples[text_column_name]+ ' ', - '[START_I_SMILES]' + examples['graph'] + '[END_I_SMILES]\n\n##Question: ' + examples[text_column_name] + '\n\nAnswer:', + '[START_I_SMILES]' + examples['graph'] + '[END_I_SMILES]\n\n##Question: ' + text + '\n\nAnswer:', padding=padding, truncation=True, max_length=max_seq_length, diff --git a/dataloaders/gpt3_smiles_collator.py b/dataloaders/gpt3_smiles_collator.py index 8facb9a..55975fd 100644 --- a/dataloaders/gpt3_smiles_collator.py +++ b/dataloaders/gpt3_smiles_collator.py @@ -24,8 +24,9 @@ def gpt3_conditional_generation_tokenizer(examples,tokenizer,text_column_name,padding,max_seq_length,**kwargs): data_new = {} + text = examples[text_column_name] if isinstance(examples[text_column_name], str) else examples[text_column_name][0] tokenized_input = tokenizer( - 'Please answer questions on this molecule. The SMILES of this molecule is:' + examples['graph'] + '\n\n##Question: ' + examples[text_column_name] + '\n\nAnswer:', + 'Please answer questions on this molecule. 
The SMILES of this molecule is:' + examples['graph'] + '\n\n##Question: ' + text + '\n\nAnswer:', padding=padding, truncation=True, max_length=max_seq_length, diff --git a/dataloaders/graph_text_transform.py b/dataloaders/graph_text_transform.py index 5f9a9f1..0f9c02c 100644 --- a/dataloaders/graph_text_transform.py +++ b/dataloaders/graph_text_transform.py @@ -71,7 +71,7 @@ def tokenize_function_gin_T5(examples,tokenizer,text_column_name,padding,max_seq # Remove empty lines # examples[text_column_name] = [line for line in examples[text_column_name] if len(line) > 0 and not line.isspace()] text = tokenizer( - examples[text_column_name], + examples[text_column_name] if isinstance(examples[text_column_name],str) else examples[text_column_name][0], padding=padding, truncation=True, max_length=max_seq_length, @@ -107,7 +107,7 @@ def tokenize_function_gimlet(examples, tokenizer, text_column_name, padding, max # Remove empty lines # examples[text_column_name] = [line for line in examples[text_column_name] if len(line) > 0 and not line.isspace()] text = tokenizer( - examples[text_column_name], + examples[text_column_name] if isinstance(examples[text_column_name],str) else examples[text_column_name][0], # if examples[text_column_name] is list padding=padding, truncation=True, max_length=max_seq_length, diff --git a/dataloaders/kvplm_smiles_collator.py b/dataloaders/kvplm_smiles_collator.py index 15fab00..299558d 100644 --- a/dataloaders/kvplm_smiles_collator.py +++ b/dataloaders/kvplm_smiles_collator.py @@ -113,9 +113,10 @@ def torch_call(self, examples): def kvplm_conditional_generation_tokenizer(examples,tokenizer,text_column_name,padding,max_seq_length,**kwargs): data_new = {} + text=examples[text_column_name] if isinstance(examples[text_column_name],str) else examples[text_column_name][0] tokenized_input = tokenizer( examples['graph'] + ' '+ - examples[text_column_name]+ ' ', + text+ ' ', padding=padding, truncation=True, max_length=max_seq_length, diff --git a/dataloaders/momu_collator.py b/dataloaders/momu_collator.py index 4101e3c..8e08e67 100644 --- a/dataloaders/momu_collator.py +++ b/dataloaders/momu_collator.py @@ -23,8 +23,9 @@ def contrastive_conditional_generation_tokenizer(examples,tokenizer,text_column_name,padding,max_seq_length,rich_features,**kwargs): label_dict={'Yes':[1],'No':[0]} data_new = {} - tokenized_input_pos=tokenizer(examples[text_column_name]+' '+'Yes',truncation=True,max_length=512) - tokenized_input_neg=tokenizer(examples[text_column_name]+' '+'No',truncation=True,max_length=512) + text=examples[text_column_name] if isinstance(examples[text_column_name],str) else examples[text_column_name][0] + tokenized_input_pos=tokenizer(text+' '+'Yes',truncation=True,max_length=512) + tokenized_input_neg=tokenizer(text+' '+'No',truncation=True,max_length=512) # if not transform_in_collator: # examples['graph'] = smiles2graph(examples['graph']) data_new['graph']=examples['graph'] diff --git a/downstream_test.py b/downstream_test.py index 9436681..ccd59b4 100644 --- a/downstream_test.py +++ b/downstream_test.py @@ -17,7 +17,7 @@ from model import get_model from dataloaders import add_prompt_transform_dict,\ graph_text_collator_dict, \ - MoleculeDatasetSplitLabel + MoleculeDatasetSplitLabel,graph_text_tokenizer_dict from transformers import ( AutoTokenizer, @@ -26,6 +26,7 @@ from tqdm import tqdm import os import re +from datasets import load_dataset os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -39,10 +40,12 @@ parser.add_argument('--disable_tqdm',action='store_true') # about 
dataset and dataloader +parser.add_argument('--use_huggingface_pipeline',action='store_true') parser.add_argument('--dataset', type=str, default='bace') parser.add_argument('--num_workers', type=int, default=0) parser.add_argument('--rich_features',action='store_true') parser.add_argument('--transform_in_collator',action='store_true') +parser.add_argument('--overwrite_data_cache',action='store_true') # about multitask strategies parser.add_argument('--task_policy',type=str,default='traversal', choices=['single','traversal','multi_mixture','multi_label']) @@ -324,65 +327,79 @@ def downstream_task_by_transform(model,train_loader,val_loader,test_loader,promp } tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, **tokenizer_kwargs) - if args.few_shot and args.few_shot_prompt_fashion!='traversal': - - def modify_name(name): - name = name.replace('.ckpt', '.pt') - name=name.replace('ckpts/','') - if name[-1]=='/': - name=name[:-1] - return name - - file_name=os.path.join('cache','result_'+args.few_shot_prompt_fashion+'_prompt_table.csv') - prompts_pd = pd.read_csv(file_name,index_col='unique_task_id') - rename_keys={} - for name in prompts_pd.columns: - rename_keys[name]=modify_name(name) - prompts_pd=prompts_pd.rename(columns=rename_keys) - prompt={} - model_name=modify_name(args.model_name_or_path) - for ind in range(get_num_task(args.dataset)): - if args.dataset + '@' + str(ind) in prompts_pd.index.values: - res=prompts_pd.loc[args.dataset+'@'+str(ind),model_name] - if pd.isna(res): - continue - prompt[str(ind)]=[res] + model=get_model(args,graph_args,tokenizer) - else: - if args.prompt_augmentation=='': - with open(os.path.join("prompts",args.prompt_file), 'r') as load_f: - prompts = commentjson.load(load_f) - prompt=prompts[args.dataset] - else: - with open(os.path.join("prompts",args.prompt_file), 'r') as load_f: - prompts = commentjson.load(load_f) - prompt_all=prompts[args.dataset] + def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + if args.return_model_size: + print('Model size: {}'.format(count_parameters(model))) + + + if not args.use_huggingface_pipeline: + #Load instruction files, and add them for molecule data. 
+ if args.few_shot and args.few_shot_prompt_fashion!='traversal': + + def modify_name(name): + name = name.replace('.ckpt', '.pt') + name=name.replace('ckpts/','') + if name[-1]=='/': + name=name[:-1] + return name + + file_name=os.path.join('cache','result_'+args.few_shot_prompt_fashion+'_prompt_table.csv') + prompts_pd = pd.read_csv(file_name,index_col='unique_task_id') + rename_keys={} + for name in prompts_pd.columns: + rename_keys[name]=modify_name(name) + prompts_pd=prompts_pd.rename(columns=rename_keys) prompt={} - for key in prompt_all: - if args.prompt_augmentation in prompt_all[key]: - prompt[key]=prompt_all[key][args.prompt_augmentation] - else: - print('label split {} has no augmentation {}'.format(key, args.prompt_augmentation)) - - if isinstance(prompt,list): - prompt_token=tokenizer(prompt,return_special_tokens_mask=True) - input_ids = [item for item in prompt_token.data['input_ids']] - attention_mask = [item for item in prompt_token.data['attention_mask']] - if args.prompt_id is None: - args.prompt_id = list(range(len(prompt))) - elif isinstance(prompt,dict): - prompt_token={} - input_ids={} - attention_mask={} - args.prompt_id={} - for key in prompt.keys(): - if len(prompt[key])>0: - prompt_token[key]=tokenizer(prompt[key],return_special_tokens_mask=True) - input_ids[key] = [item for item in prompt_token[key].data['input_ids']] - attention_mask[key] = [item for item in prompt_token[key].data['attention_mask']] - args.prompt_id[key] = list(range(len(prompt[key]))) + model_name=modify_name(args.model_name_or_path) + for ind in range(get_num_task(args.dataset)): + if args.dataset + '@' + str(ind) in prompts_pd.index.values: + res=prompts_pd.loc[args.dataset+'@'+str(ind),model_name] + if pd.isna(res): + continue + prompt[str(ind)]=[res] + + else: + if args.prompt_augmentation=='': + with open(os.path.join("prompts",args.prompt_file), 'r') as load_f: + prompts = commentjson.load(load_f) + prompt=prompts[args.dataset] + else: + with open(os.path.join("prompts",args.prompt_file), 'r') as load_f: + prompts = commentjson.load(load_f) + prompt_all=prompts[args.dataset] + prompt={} + for key in prompt_all: + if args.prompt_augmentation in prompt_all[key]: + prompt[key]=prompt_all[key][args.prompt_augmentation] + else: + print('label split {} has no augmentation {}'.format(key, args.prompt_augmentation)) + + if isinstance(prompt,list): + prompt_token=tokenizer(prompt,return_special_tokens_mask=True) + input_ids = [item for item in prompt_token.data['input_ids']] + attention_mask = [item for item in prompt_token.data['attention_mask']] + if args.prompt_id is None: + args.prompt_id = list(range(len(prompt))) + elif isinstance(prompt,dict): + prompt_token={} + input_ids={} + attention_mask={} + args.prompt_id={} + for key in prompt.keys(): + if len(prompt[key])>0: + prompt_token[key]=tokenizer(prompt[key],return_special_tokens_mask=True) + input_ids[key] = [item for item in prompt_token[key].data['input_ids']] + attention_mask[key] = [item for item in prompt_token[key].data['attention_mask']] + args.prompt_id[key] = list(range(len(prompt[key]))) + else: + raise ValueError('Prompt type not supported. Only list or dict of (list of) prompts are supported.') + else: - raise ValueError('Prompt type not supported. Only list or dict of (list of) prompts are supported.') + print('Using huggingface pipeline. 
Prompt file not loaded.') label_ignore = [-100] raw_label = {1: 'Yes', 0: 'No', 'invalid': label_ignore} @@ -393,168 +410,311 @@ def modify_name(name): # Bunch of classification tasks num_tasks = get_num_task(args.dataset) - dataset_folder = 'property_data/' - if args.transformer_backbone in ['kvplm', 'galactica','gpt3']: - dataset = MoleculeDatasetSplitLabel(root=dataset_folder, name=args.dataset,return_smiles=True,split_label=args.split_label,single_split=args.single_split,rich_features=args.rich_features) - else: - dataset = MoleculeDatasetSplitLabel(root=dataset_folder, name=args.dataset,split_label=args.split_label,single_split=args.single_split,rich_features=args.rich_features) - - print(dataset) - print(dataset[0]) - - - if args.split == 'scaffold': - # if args.single_split is not None: - smiles_list = pd.read_csv(dataset_folder + args.dataset + '/processed/smiles.csv', - header=None)[0].tolist() - train_index, valid_index, test_index = scaffold_split( - torch.arange(len(smiles_list)), smiles_list, null_value=0, frac_train=0.8, - frac_valid=0.1, frac_test=0.1) - - train_index_total=[] - valid_index_total=[] - test_index_total=[] - for times in range(dataset.label_number): - train_index_times=train_index+times*dataset.len_oridata() - valid_index_times = valid_index + times * dataset.len_oridata() - test_index_times = test_index + times * dataset.len_oridata() - - train_index_total.append(train_index_times) - valid_index_total.append(valid_index_times) - test_index_total.append(test_index_times) - train_index_total=torch.cat(train_index_total,0) - valid_index_total=torch.cat(valid_index_total,0) - test_index_total=torch.cat(test_index_total,0) - - train_dataset = dataset[train_index_total] - valid_dataset = dataset[valid_index_total] - test_dataset = dataset[test_index_total] - - print('split via scaffold') - elif args.split == 'random': - train_dataset, valid_dataset, test_dataset = random_split( - dataset, null_value=0, frac_train=0.8, frac_valid=0.1, - frac_test=0.1, seed=args.seed) - print('randomly split') - elif args.split == 'random_scaffold': - smiles_list = pd.read_csv(dataset_folder + args.dataset + '/processed/smiles.csv', - header=None)[0].tolist() - train_dataset, valid_dataset, test_dataset = random_scaffold_split( - dataset, smiles_list, null_value=0, frac_train=0.8, - frac_valid=0.1, frac_test=0.1, seed=args.seed) - print('random scaffold') - else: - raise ValueError('Invalid split option.') - print(train_dataset[0]) + if not args.use_huggingface_pipeline: + # Loading Molecule Dataset + dataset_folder = 'property_data/' - data_collator = graph_text_collator_dict[args.transformer_backbone]( - tokenizer=tokenizer, - transform_in_collator=args.transform_in_collator, - rich_features=args.rich_features) + if args.transformer_backbone in ['kvplm', 'galactica','gpt3']: + dataset = MoleculeDatasetSplitLabel(root=dataset_folder, name=args.dataset,return_smiles=True,split_label=args.split_label,single_split=args.single_split,rich_features=args.rich_features) + else: + dataset = MoleculeDatasetSplitLabel(root=dataset_folder, name=args.dataset,split_label=args.split_label,single_split=args.single_split,rich_features=args.rich_features) - train_loader = DataLoader(train_dataset, batch_size=args.batch_size, - shuffle=True, num_workers=args.num_workers,collate_fn=data_collator) - val_loader = DataLoader(valid_dataset, batch_size=args.batch_size, - shuffle=False, num_workers=args.num_workers,collate_fn=data_collator) - test_loader = DataLoader(test_dataset, batch_size=args.batch_size, 
- shuffle=False, num_workers=args.num_workers,collate_fn=data_collator) + print(dataset) + print(dataset[0]) - model=get_model(args,graph_args,tokenizer) + if args.split == 'scaffold': + # if args.single_split is not None: + smiles_list = pd.read_csv(dataset_folder + args.dataset + '/processed/smiles.csv', + header=None)[0].tolist() + train_index, valid_index, test_index = scaffold_split( + torch.arange(len(smiles_list)), smiles_list, null_value=0, frac_train=0.8, + frac_valid=0.1, frac_test=0.1) - def count_parameters(model): - return sum(p.numel() for p in model.parameters() if p.requires_grad) + train_index_total=[] + valid_index_total=[] + test_index_total=[] + for times in range(dataset.label_number): + train_index_times=train_index+times*dataset.len_oridata() + valid_index_times = valid_index + times * dataset.len_oridata() + test_index_times = test_index + times * dataset.len_oridata() + + train_index_total.append(train_index_times) + valid_index_total.append(valid_index_times) + test_index_total.append(test_index_times) + train_index_total=torch.cat(train_index_total,0) + valid_index_total=torch.cat(valid_index_total,0) + test_index_total=torch.cat(test_index_total,0) + + train_dataset = dataset[train_index_total] + valid_dataset = dataset[valid_index_total] + test_dataset = dataset[test_index_total] + + print('split via scaffold') + elif args.split == 'random': + train_dataset, valid_dataset, test_dataset = random_split( + dataset, null_value=0, frac_train=0.8, frac_valid=0.1, + frac_test=0.1, seed=args.seed) + print('randomly split') + elif args.split == 'random_scaffold': + smiles_list = pd.read_csv(dataset_folder + args.dataset + '/processed/smiles.csv', + header=None)[0].tolist() + train_dataset, valid_dataset, test_dataset = random_scaffold_split( + dataset, smiles_list, null_value=0, frac_train=0.8, + frac_valid=0.1, frac_test=0.1, seed=args.seed) + print('random scaffold') + else: + raise ValueError('Invalid split option.') + print(train_dataset[0]) - if args.return_model_size: - print('Model size: {}'.format(count_parameters(model))) + data_collator = graph_text_collator_dict[args.transformer_backbone]( + tokenizer=tokenizer, + transform_in_collator=args.transform_in_collator, + rich_features=args.rich_features) + train_loader = DataLoader(train_dataset, batch_size=args.batch_size, + shuffle=True, num_workers=args.num_workers,collate_fn=data_collator) + val_loader = DataLoader(valid_dataset, batch_size=args.batch_size, + shuffle=False, num_workers=args.num_workers,collate_fn=data_collator) + test_loader = DataLoader(test_dataset, batch_size=args.batch_size, + shuffle=False, num_workers=args.num_workers,collate_fn=data_collator) - if args.task_policy =='traversal': - recurrent_range=range(num_tasks) - elif args.task_policy =='single': - recurrent_range = [args.single_split] else: - raise ValueError('prompt_policy not implemented yet') + # Loading Huggingface Dataset + dataset = load_dataset("haitengzhao/molecule_property_instruction", + # download_mode = "force_redownload" + )[args.dataset] + - if args.not_retest_tasks_in_result_file: - if os.path.exists(args.output_result_to_file): - result_file=pd.read_csv(args.output_result_to_file,header=0,index_col=0) + print(dataset) + print(dataset[0]) + + if args.split == 'scaffold': + train_dataset_total = dataset.filter(lambda example: (example["split"] == 'train')) + valid_dataset_total = dataset.filter(lambda example: (example["split"] == 'valid')) + test_dataset_total = dataset.filter(lambda example: (example["split"] == 
'test')) else: - result_file=None - - for single_split_label in recurrent_range: - if args.task_policy in ['traversal','single']: - print('label split: ',single_split_label) - if not str(single_split_label) in prompt: - print('No prompt for label split {}'.format(single_split_label)) - continue - if args.not_retest_tasks_in_result_file and result_file is not None: - if len(result_file[(result_file['dataset']==args.dataset) & (result_file['split']==single_split_label)])>0: - print(args.dataset,' ',single_split_label,'has been tested') - continue + raise ValueError('Not implied split option for huggingface pipeline.') - train_loader.dataset.set_single_split(single_split_label) - val_loader.dataset.set_single_split(single_split_label) - test_loader.dataset.set_single_split(single_split_label) - - dataset.set_single_split(single_split_label) - if args.few_shot is not None: - ind_each_class = {} - for ind in train_index_total: - label=int(dataset[ind].y) - if label not in ind_each_class: - ind_each_class[label]=[ind] - else: - ind_each_class[label].append(ind) + def select_single_prompt(example, prompt_id): + example["text"] = example["text"][prompt_id] + return example - for key in ind_each_class.keys(): - ind_each_class[key]=np.random.choice(ind_each_class[key], size=min(len(ind_each_class[key]),args.few_shot),replace=False).tolist() - train_index_total=[] - for key in ind_each_class.keys(): - train_index_total+=ind_each_class[key] - train_dataset = dataset[train_index_total] - train_loader = DataLoader(train_dataset, batch_size=args.batch_size, - shuffle=True, num_workers=args.num_workers, collate_fn=data_collator) - train_loader.dataset.set_single_split(single_split_label) + tokenize_function = lambda x: graph_text_tokenizer_dict[args.transformer_backbone](examples=x, + tokenizer=tokenizer, + text_column_name='text', + padding=False, + max_seq_length=None, + rich_features=args.rich_features, + transform_in_collator=( + args.transform_in_collator)) - if args.prompt_policy == 'single': - print(prompt[args.prompt_id[0]]) + data_collator = graph_text_collator_dict[args.transformer_backbone]( + tokenizer=tokenizer, + transform_in_collator=args.transform_in_collator, + rich_features=args.rich_features) - #add prompt to graph data by data transform - transform=lambda x: add_prompt_transform_dict[args.transformer_backbone]( - data=x,data_label=x.y,input_ids=input_ids[args.prompt_id[0]], - attention_mask=attention_mask[args.prompt_id[0]],label_dict=label_dict, - rich_features=args.rich_features,transform_in_collator=args.transform_in_collator, - raw_prompts=prompt[args.prompt_id[0]],raw_label=raw_label,tokenizer=tokenizer, - generaltive_label=(task_type(args.dataset)=='reg')) - train_loader.dataset.transform = transform - val_loader.dataset.transform = transform - test_loader.dataset.transform = transform + if not args.use_huggingface_pipeline: #Different pre-processing for the two types of pipelines. 
- downstream_task_by_transform(model,train_loader,val_loader,test_loader,prompt[args.prompt_id[0]]) + if args.task_policy =='traversal': + recurrent_range=range(num_tasks) + elif args.task_policy =='single': + recurrent_range = [args.single_split] + else: + raise ValueError('prompt_policy not implemented yet') - elif args.prompt_policy == 'traversal': - for prompt_id in args.prompt_id[str(single_split_label)]: - print(prompt[str(single_split_label)][prompt_id]) + if args.not_retest_tasks_in_result_file: + if os.path.exists(args.output_result_to_file): + result_file=pd.read_csv(args.output_result_to_file,header=0,index_col=0) + else: + result_file=None + + for single_split_label in recurrent_range: + if args.task_policy in ['traversal','single']: + print('label split: ',single_split_label) + if not str(single_split_label) in prompt: + print('No prompt for label split {}'.format(single_split_label)) + continue + if args.not_retest_tasks_in_result_file and result_file is not None: + if len(result_file[(result_file['dataset']==args.dataset) & (result_file['split']==single_split_label)])>0: + print(args.dataset,' ',single_split_label,'has been tested') + continue + train_loader.dataset.set_single_split(single_split_label) + val_loader.dataset.set_single_split(single_split_label) + test_loader.dataset.set_single_split(single_split_label) + + dataset.set_single_split(single_split_label) + if args.few_shot is not None: + ind_each_class = {} + for ind in train_index_total: + label=int(dataset[ind].y) + if label not in ind_each_class: + ind_each_class[label]=[ind] + else: + ind_each_class[label].append(ind) + + for key in ind_each_class.keys(): + ind_each_class[key]=np.random.choice(ind_each_class[key], size=min(len(ind_each_class[key]),args.few_shot),replace=False).tolist() + train_index_total=[] + for key in ind_each_class.keys(): + train_index_total+=ind_each_class[key] + + train_dataset = dataset[train_index_total] + train_loader = DataLoader(train_dataset, batch_size=args.batch_size, + shuffle=True, num_workers=args.num_workers, collate_fn=data_collator) + train_loader.dataset.set_single_split(single_split_label) + + if args.prompt_policy == 'single': + print(prompt[args.prompt_id[0]]) + + #add prompt to graph data by data transform transform=lambda x: add_prompt_transform_dict[args.transformer_backbone]( - data=x,data_label=x.y,input_ids=input_ids[str(single_split_label)][prompt_id], - attention_mask=attention_mask[str(single_split_label)][prompt_id],label_dict=label_dict, - rich_features=args.rich_features,transform_in_collator=args.transform_in_collator, - raw_prompts=prompt[str(single_split_label)][prompt_id],raw_label=raw_label,tokenizer=tokenizer, + data=x,data_label=x.y,input_ids=input_ids[args.prompt_id[0]], + attention_mask=attention_mask[args.prompt_id[0]],label_dict=label_dict, + rich_features=args.rich_features,transform_in_collator=args.transform_in_collator, + raw_prompts=prompt[args.prompt_id[0]],raw_label=raw_label,tokenizer=tokenizer, generaltive_label=(task_type(args.dataset)=='reg')) train_loader.dataset.transform = transform val_loader.dataset.transform = transform test_loader.dataset.transform = transform - downstream_task_by_transform(model,train_loader,val_loader,test_loader,prompt[str(single_split_label)][prompt_id]) + downstream_task_by_transform(model,train_loader,val_loader,test_loader,prompt[args.prompt_id[0]]) + elif args.prompt_policy == 'traversal': + for prompt_id in args.prompt_id[str(single_split_label)]: + print(prompt[str(single_split_label)][prompt_id]) + + 
transform=lambda x: add_prompt_transform_dict[args.transformer_backbone]( + data=x,data_label=x.y,input_ids=input_ids[str(single_split_label)][prompt_id], + attention_mask=attention_mask[str(single_split_label)][prompt_id],label_dict=label_dict, + rich_features=args.rich_features,transform_in_collator=args.transform_in_collator, + raw_prompts=prompt[str(single_split_label)][prompt_id],raw_label=raw_label,tokenizer=tokenizer, + generaltive_label=(task_type(args.dataset)=='reg')) + + train_loader.dataset.transform = transform + val_loader.dataset.transform = transform + test_loader.dataset.transform = transform + + downstream_task_by_transform(model,train_loader,val_loader,test_loader,prompt[str(single_split_label)][prompt_id]) + + else: + raise ValueError('prompt_policy not implemented yet') + + else: #Huggingface pipelie + + if args.task_policy == 'traversal': + recurrent_range = range(num_tasks) + elif args.task_policy == 'single': + recurrent_range = [args.single_split] else: raise ValueError('prompt_policy not implemented yet') + if args.not_retest_tasks_in_result_file: + if os.path.exists(args.output_result_to_file): + result_file = pd.read_csv(args.output_result_to_file, header=0, index_col=0) + else: + result_file = None + + for single_split_label in recurrent_range: + if args.task_policy in ['traversal', 'single']: + print('label split: ', single_split_label) + + if args.not_retest_tasks_in_result_file and result_file is not None: + if len(result_file[ + (result_file['dataset'] == args.dataset) & ( + result_file['split'] == single_split_label)]) > 0: + print(args.dataset, ' ', single_split_label, 'has been tested') + continue + + train_dataset_task = train_dataset_total.filter( + lambda example: (example["task_index"] == str(single_split_label))) + valid_dataset_task = valid_dataset_total.filter( + lambda example: (example["task_index"] == str(single_split_label))) + test_dataset_task = test_dataset_total.filter(lambda example: (example["task_index"] == str(single_split_label))) + + if len(test_dataset_task) == 0: + print('No label or prompt for label split {}'.format(single_split_label)) + continue + + if args.prompt_policy == 'single': + # print() + prompt_id_range = [args.prompt_id[0]] + elif args.prompt_policy == 'traversal': + if args.prompt_id is None: + prompt_id_range = range(len(train_dataset_task[0]['text'])) + else: + prompt_id_range = args.prompt_id[str(single_split_label)] + else: + raise ValueError('prompt_policy not implemented yet') + + for prompt_id in prompt_id_range: + + train_dataset = train_dataset_task.map(lambda example: select_single_prompt(example, prompt_id)) + valid_dataset = valid_dataset_task.map(lambda example: select_single_prompt(example, prompt_id)) + test_dataset = test_dataset_task.map(lambda example: select_single_prompt(example, prompt_id)) + + prompt = train_dataset[0]['text'] + print(prompt) + + train_dataset = train_dataset.map( + tokenize_function, + batched=False, + num_proc=None, + remove_columns=['text'], + load_from_cache_file=not args.overwrite_data_cache, + desc="Running tokenizer on dataset line_by_line", + ) + valid_dataset = valid_dataset.map( + tokenize_function, + batched=False, + num_proc=None, + remove_columns=['text'], + load_from_cache_file=not args.overwrite_data_cache, + desc="Running tokenizer on dataset line_by_line", + ) + test_dataset = test_dataset.map( + tokenize_function, + batched=False, + num_proc=None, + remove_columns=['text'], + load_from_cache_file=not args.overwrite_data_cache, + desc="Running tokenizer on 
dataset line_by_line", + ) + + train_loader = DataLoader(train_dataset, batch_size=args.batch_size, + shuffle=True, num_workers=args.num_workers, collate_fn=data_collator) + val_loader = DataLoader(valid_dataset, batch_size=args.batch_size, + shuffle=False, num_workers=args.num_workers, collate_fn=data_collator) + test_loader = DataLoader(test_dataset, batch_size=args.batch_size, + shuffle=False, num_workers=args.num_workers, collate_fn=data_collator) + + if args.few_shot is not None: + ind_each_class = {} + for ind, data in enumerate(train_dataset): + label = data['label'] + if label not in ind_each_class: + ind_each_class[label] = [ind] + else: + ind_each_class[label].append(ind) + + for key in ind_each_class.keys(): + ind_each_class[key] = np.random.choice(ind_each_class[key], + size=min(len(ind_each_class[key]), args.few_shot), + replace=False).tolist() + train_index_total = [] + for key in ind_each_class.keys(): + train_index_total += ind_each_class[key] + + train_dataset = train_dataset.select(train_index_total) + train_loader = DataLoader(train_dataset, batch_size=args.batch_size, + shuffle=True, num_workers=args.num_workers, collate_fn=data_collator) + + downstream_task_by_transform(model, train_loader, val_loader, test_loader, + prompt) diff --git a/model/GIMLET/GIMLETTransformerForConditionalGeneration.py b/model/GIMLET/GIMLETTransformerForConditionalGeneration.py index 2c436b2..7c96a81 100644 --- a/model/GIMLET/GIMLETTransformerForConditionalGeneration.py +++ b/model/GIMLET/GIMLETTransformerForConditionalGeneration.py @@ -7,6 +7,7 @@ BaseModelOutput, Seq2SeqLMOutput, ) +from transformers import AutoConfig,PretrainedConfig from model.GIMLET.GIMLETEncoderStack import GraphT5EncoderStack_dict import copy logger = logging.get_logger(__name__) @@ -33,7 +34,7 @@ class GraphT5TransformerForConditionalGeneration(T5ForConditionalGeneration): r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", ] - def __init__(self, config,graph_args): + def __init__(self, config,graph_args=None): #for debug # config.dropout_rate=0.0 @@ -44,6 +45,13 @@ def __init__(self, config,graph_args): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.is_encoder_decoder = False + + if graph_args is None: + assert hasattr(config,'graph_args') + graph_args= PretrainedConfig.from_dict(config.graph_args) + else: + config.graph_args = vars(graph_args) + self.config.loss_reduction_method = getattr(graph_args,'loss_reduction_method') self.encoder = GraphT5EncoderStack_dict[graph_args.transformer_backbone]\ (encoder_config,graph_args, self.shared) diff --git a/model/__init__.py b/model/__init__.py index 211e7d5..3da6dff 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -21,26 +21,36 @@ def get_model(args,graph_args,tokenizer): if not (args.transformer_backbone in ['kvplm','momu','galactica','gpt3']): - config_kwargs = { - "cache_dir": None, - "revision": 'main', - "use_auth_token": None, - } - config = AutoConfig.from_pretrained(args.tokenizer_name, **config_kwargs) - config.vocab_size=len(tokenizer) - graph_args.transformer_backbone = args.transformer_backbone - model = GraphTransformer_dict[args.transformer_backbone].from_pretrained( - args.model_name_or_path, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - graph_args=graph_args, - cache_dir=None, - revision='main', - use_auth_token=None, - ignore_mismatched_sizes=True, - ) - model.resize_token_embeddings(len(tokenizer)) + if args.model_name_or_path=='haitengzhao/gimlet': + model = 
GraphTransformer_dict[args.transformer_backbone].from_pretrained( + args.model_name_or_path, + ) + + else: #load from local file: + config_kwargs = { + "cache_dir": None, + "revision": 'main', + "use_auth_token": None, + } + config = AutoConfig.from_pretrained(args.tokenizer_name, **config_kwargs) + config.vocab_size=len(tokenizer) + graph_args.transformer_backbone = args.transformer_backbone + config.graph_args = vars(graph_args) #use the user-provided graph args + + model = GraphTransformer_dict[args.transformer_backbone].from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + # graph_args=graph_args, + cache_dir=None, + revision='main', + use_auth_token=None, + ignore_mismatched_sizes=True, + ) + model.resize_token_embeddings(len(tokenizer)) + + elif args.transformer_backbone == 'kvplm': model = GraphTransformer_dict[args.transformer_backbone](graph_args) elif args.transformer_backbone == 'momu': diff --git a/pretraining_gimlet.py b/pretraining_gimlet.py index cc0d0d4..0dfa576 100644 --- a/pretraining_gimlet.py +++ b/pretraining_gimlet.py @@ -5,7 +5,7 @@ from itertools import chain from typing import Optional import datasets -from datasets import load_dataset +from datasets import load_dataset,DatasetDict import transformers from transformers import ( CONFIG_MAPPING, @@ -185,18 +185,18 @@ class DataTrainingArguments: transform_in_collator: Optional[bool] = field(default=False) wrap_dataset: Optional[bool] = field(default=False) - def __post_init__(self): - if self.dataset_name is None and self.train_file is None and self.validation_file is None: - raise ValueError("Need either a dataset name or a training/validation file.") - else: - if self.train_file is not None: - extension = self.train_file.split(".")[-1] - if extension not in ["csv", "json", "txt"]: - raise ValueError("`train_file` should be a csv, a json or a txt file.") - if self.validation_file is not None: - extension = self.validation_file.split(".")[-1] - if extension not in ["csv", "json", "txt"]: - raise ValueError("`validation_file` should be a csv, a json or a txt file.") + # def __post_init__(self): + # if self.dataset_name is None and self.train_file is None and self.validation_file is None: + # raise ValueError("Need either a dataset name or a training/validation file.") + # else: + # if self.train_file is not None: + # extension = self.train_file.split(".")[-1] + # if extension not in ["csv", "json", "txt"]: + # raise ValueError("`train_file` should be a csv, a json or a txt file.") + # if self.validation_file is not None: + # extension = self.validation_file.split(".")[-1] + # if extension not in ["csv", "json", "txt"]: + # raise ValueError("`validation_file` should be a csv, a json or a txt file.") # def eval_result(trainer,task_type='cla'): # @@ -416,29 +416,39 @@ def main(): # Set seed before initializing model. set_seed(training_args.seed) - if data_args.dataset_name is not None: - # Downloading and loading a dataset from the hub. 
- raw_datasets = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - if "validation" not in raw_datasets.keys(): - raw_datasets["validation"] = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - split=f"train[:{data_args.validation_split_percentage}%]", - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - raw_datasets["train"] = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - split=f"train[{data_args.validation_split_percentage}%:]", - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) + # if data_args.dataset_name is not None: + # # Downloading and loading a dataset from the hub. + # raw_datasets = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # cache_dir=model_args.cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + # if "validation" not in raw_datasets.keys(): + # raw_datasets["validation"] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=f"train[:{data_args.validation_split_percentage}%]", + # cache_dir=model_args.cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + # raw_datasets["train"] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=f"train[{data_args.validation_split_percentage}%:]", + # cache_dir=model_args.cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + if data_args.train_file=='haitengzhao/molecule_property_instruction' or data_args.validation_file=='haitengzhao/molecule_property_instruction': + raw_datasets={} + dataset_full=load_dataset("haitengzhao/molecule_property_instruction", + # download_mode = "force_redownload" + ) + if training_args.do_train: + raw_datasets['train']=dataset_full['chembl_pretraining'] + if training_args.do_eval: + raw_datasets['validation']=dataset_full['chembl_zero_shot'] + raw_datasets=DatasetDict(raw_datasets) else: data_files = {} if data_args.train_file is not None: From 4df94ad767f1701aebeefb445ed1817001f2e05c Mon Sep 17 00:00:00 2001 From: zhao-ht Date: Thu, 13 Jul 2023 16:07:27 +0800 Subject: [PATCH 3/9] updata_file --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 2223041..ef5b5b7 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ We also benchmark baselines including KVPLM, MoMu, and Galactica on our downstre ### 2023.7.10 -**1.** Now the datasets and the GIMLET model can be download directly from HuggingFace: [https://huggingface.co/datasets/haitengzhao/molecule_property_instruction](https://huggingface.co/datasets/haitengzhao/molecule_property_instruction) and [https://huggingface.co/haitengzhao/gimlet](https://huggingface.co/haitengzhao/gimlet). +**1.** Now the datasets and the GIMLET model can be download directly from HuggingFace 🤗 : [https://huggingface.co/datasets/haitengzhao/molecule_property_instruction](https://huggingface.co/datasets/haitengzhao/molecule_property_instruction) and [https://huggingface.co/haitengzhao/gimlet](https://huggingface.co/haitengzhao/gimlet). The GIMLET model can be downloaded and used as follows: @@ -101,7 +101,7 @@ pip install openai #### Method 1: HuggingFace -Our model can now be downloaded from HuggingFace. 
To download the model parameters, you can simply specify **--model_name_or_path** as **haitengzhao/gimlet**. Here's an example: +Our model can now be downloaded from HuggingFace 🤗 . To download the model parameters, you can simply specify **--model_name_or_path** as **haitengzhao/gimlet**. Here's an example: ``` from model import GraphT5TransformerForConditionalGeneration @@ -132,7 +132,7 @@ In this case, the **--model_name_or_path** refers to the path of the checkpoint ### Dataset Download #### Method 1: HuggingFace -Our datasets is available for download on HuggingFace. You can automatically download the datasets and use the huggingface dataset pipeline by augment **--use_huggingface_pipeline**. +Our datasets is available for download on HuggingFace 🤗 . You can automatically download the datasets and use the huggingface dataset pipeline by augment **--use_huggingface_pipeline**. #### Method 2: Manual Download Alternatively, you can run experiments from the original molecule datasets. In this pipeline, we will incorporate instruction text to the molecule data during the experimentation process. From 36efd729c11f2829a6fb645e30d83db666838dd0 Mon Sep 17 00:00:00 2001 From: zhao-ht Date: Tue, 18 Jul 2023 21:52:41 +0800 Subject: [PATCH 4/9] updata_file --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ef5b5b7..bf9af17 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # GIMLET -This is the code for paper [GIMLET: A Unified Graph-Text Model for Instruction-Based Molecule Zero-Shot Learning](https://www.biorxiv.org/content/10.1101/2023.05.30.542904). +This is the code for paper [GIMLET: A Unified Graph-Text Model for Instruction-Based Molecule Zero-Shot Learning](https://arxiv.org/pdf/2306.13089.pdf). GIMLET is a unified transformer model for both graph and text data and is pretrained on large scale molecule tasks with instructions, towards instruction-based molecule zero-shot learning. The framework and pretraining & downstream tasks are as follows: @@ -283,7 +283,7 @@ In this scenario, the pretraining data is the file "pretrain_datasets/merge_spli ## Citation -Please cite our paper if you find it helpful. +Please cite our paper if you find it helpful or use our datasets. ``` @article{zhao2023gimlet, title={GIMLET: A Unified Graph-Text Model for Instruction-Based Molecule Zero-Shot Learning}, From e002706d985c7613d8ebfd751edc11717e1240a5 Mon Sep 17 00:00:00 2001 From: zhao-ht Date: Sun, 24 Sep 2023 20:06:55 +0800 Subject: [PATCH 5/9] updata_file --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index bf9af17..a2066b0 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,10 @@ We also benchmark baselines including KVPLM, MoMu, and Galactica on our downstre ## Updates +### 2023.9.24 + +Out work has been accepted at NeurIPS 2023! The camera ready paper is coming soon. + ### 2023.7.10 **1.** Now the datasets and the GIMLET model can be download directly from HuggingFace 🤗 : [https://huggingface.co/datasets/haitengzhao/molecule_property_instruction](https://huggingface.co/datasets/haitengzhao/molecule_property_instruction) and [https://huggingface.co/haitengzhao/gimlet](https://huggingface.co/haitengzhao/gimlet). 
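For reference, the updated downstream_test.py loads this dataset by downstream-task name and then filters on its "split" and "task_index" fields. The snippet below is a minimal sketch of that access pattern; the field names follow downstream_test.py, while the choice of "bace" and task index "0" is only an illustration.

```
from datasets import load_dataset

# Each downstream task (bace, hiv, bbbp, ...) is stored under its own key,
# alongside chembl_pretraining and chembl_zero_shot used by pretraining_gimlet.py.
dataset = load_dataset("haitengzhao/molecule_property_instruction")["bace"]

# Select the scaffold test split and one task index, mirroring downstream_test.py.
test_split = dataset.filter(lambda example: example["split"] == "test")
task_0 = test_split.filter(lambda example: example["task_index"] == "0")

example = task_0[0]
print(example["graph"])    # the molecule (the "graph" column)
print(example["text"][0])  # one instruction prompt for this task
print(example["label"])    # the ground-truth answer
```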
From 8bd89a27c21be79ffb5de23370477ef4326d1c7f Mon Sep 17 00:00:00 2001 From: zhao-ht <49550068+zhao-ht@users.noreply.github.com> Date: Thu, 2 Nov 2023 13:20:57 +0800 Subject: [PATCH 6/9] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a2066b0..18bd12d 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # GIMLET -This is the code for paper [GIMLET: A Unified Graph-Text Model for Instruction-Based Molecule Zero-Shot Learning](https://arxiv.org/pdf/2306.13089.pdf). +This is the code for paper [GIMLET: A Unified Graph-Text Model for Instruction-Based Molecule Zero-Shot Learning](https://arxiv.org/pdf/2306.13089.pdf) published at NeurIPS 2023. GIMLET is a unified transformer model for both graph and text data and is pretrained on large scale molecule tasks with instructions, towards instruction-based molecule zero-shot learning. The framework and pretraining & downstream tasks are as follows: From 9a2a573c8e44da4d0201e34b259137ebe93c5498 Mon Sep 17 00:00:00 2001 From: zhao-ht <49550068+zhao-ht@users.noreply.github.com> Date: Tue, 21 Nov 2023 19:04:48 +0800 Subject: [PATCH 7/9] Specify the transformers version (4.28.1) in environment --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 18bd12d..0eaba50 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ pip install torch_spline_conv-1.2.1-cp37-cp37m-linux_x86_64.whl pip install torch_geometric==1.7.2 -git clone https://github.com/huggingface/transformers +git clone -b v4.28.1 https://github.com/huggingface/transformers cd transformers pip install --editable ./ From b3728ff400a3e996d1f1b8d648ec0e9bf3e3408c Mon Sep 17 00:00:00 2001 From: zhao-ht <49550068+zhao-ht@users.noreply.github.com> Date: Wed, 27 Dec 2023 23:07:20 +0800 Subject: [PATCH 8/9] Create LICENSE --- LICENSE | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 LICENSE diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..08465f5 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 haiteng zhao + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
From ec741f40bec2d020af3f0338cbc1f7a6d278ea8d Mon Sep 17 00:00:00 2001
From: zhao-ht <49550068+zhao-ht@users.noreply.github.com>
Date: Thu, 22 Feb 2024 16:00:29 +0800
Subject: [PATCH 9/9] Update README.md

---
 README.md | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index 0eaba50..f9127e7 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,7 @@ We also benchmark baselines including KVPLM, MoMu, and Galactica on our downstre
 
 ### 2023.9.24
 
-Out work has been accepted at NeurIPS 2023! The camera ready paper is coming soon.
+Our work has been accepted at NeurIPS 2023! The camera-ready paper is at [https://proceedings.neurips.cc/paper_files/paper/2023/file/129033c7c08be683059559e8d6bfd460-Paper-Conference.pdf](https://proceedings.neurips.cc/paper_files/paper/2023/file/129033c7c08be683059559e8d6bfd460-Paper-Conference.pdf).
 
 ### 2023.7.10
 
@@ -289,13 +289,12 @@ In this scenario, the pretraining data is the file "pretrain_datasets/merge_spli
 
 Please cite our paper if you find it helpful or use our datasets.
 
 ```
-@article{zhao2023gimlet,
-  title={GIMLET: A Unified Graph-Text Model for Instruction-Based Molecule Zero-Shot Learning},
-  author={Zhao, Haiteng and Liu, Shengchao and Ma, Chang and Xu, Hannan and Fu, Jie and Deng, Zhi-Hong and Kong, Lingpeng and Liu, Qi},
-  journal={bioRxiv},
-  pages={2023--05},
-  year={2023},
-  publisher={Cold Spring Harbor Laboratory}
+@article{zhao2024gimlet,
+  title={Gimlet: A unified graph-text model for instruction-based molecule zero-shot learning},
+  author={Zhao, Haiteng and Liu, Shengchao and Chang, Ma and Xu, Hannan and Fu, Jie and Deng, Zhihong and Kong, Lingpeng and Liu, Qi},
+  journal={Advances in Neural Information Processing Systems},
+  volume={36},
+  year={2024}
 }
 ```