Skip to content

Commit

Permalink
Use yaml file for glue config
Browse files Browse the repository at this point in the history
  • Loading branch information
shikher7 committed Dec 22, 2022
1 parent ffb22dd commit b5387f7
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 18 deletions.
71 changes: 53 additions & 18 deletions run_glue.sh
Original file line number Diff line number Diff line change
@@ -1,25 +1,60 @@

# include parse_yaml function
pip install -r requirements.txt
export MODEL_TYPE="bert"
ls
source ./yaml_parser.sh

# model path should be same location as this shell file
export MODEL_PATH="bert-base-uncased"
export TASK_NAME="RTE"
export GLUE_DIR="data/glue_data"
export OUTPUT_DIR="output_glue"
export MAX_SEQ_LENGTH=128
export PER_GPU_TRAIN_BATCH_SIZE=32
export LEARNING_RATE=2e-5
export NUM_TRAIN_EPOCHS=3.0
#python src/glue/download_glue_data.py
export output_dir="output_glue"

# Load the config_* variables BEFORE anything below expands them: creating the
# per-task directories first would expand every task name to an empty string.
# NOTE: the parser output is evaluated as shell code, so the YAML file must be
# trusted input.
eval "$(parse_yaml src/config/glue_config.yaml "config_")"

# One output directory per configured task. -p also creates $output_dir itself
# and does not complain when the directories are left over from an earlier run.
mkdir -p -- \
  "${output_dir}/${config_task1_task_name}" \
  "${output_dir}/${config_task2_task_name}" \
  "${output_dir}/${config_task3_task_name}"


echo $config_task1_model_type

python src/glue/run_glue.py \
--model_type $MODEL_TYPE \
--model_name_or_path $MODEL_PATH \
--task_name $TASK_NAME \
--model_type $config_task1_model_type \
--model_name_or_path $config_task1_model_path \
--task_name $config_task1_task_name \
--do_train \
--do_eval \
--do_lower_case \
--data_dir $GLUE_DIR/$TASK_NAME/ \
--max_seq_length $MAX_SEQ_LENGTH \
--per_gpu_train_batch_size $PER_GPU_TRAIN_BATCH_SIZE \
--learning_rate $LEARNING_RATE \
--num_train_epochs $NUM_TRAIN_EPOCHS \
--output_dir $OUTPUT_DIR
--data_dir $config_task1_glue_dir/$config_task1_task_name/ \
--max_seq_length $config_task1_max_seq_length \
--per_gpu_train_batch_size $config_task1_per_gpu_train_batch_size \
--learning_rate $config_task1_learning_rate \
--num_train_epochs $config_task1_num_train_epochs \
--output_dir $output_dir/$config_task1_task_name

# Train and evaluate on the task described by the config_task2_* variables.
# Every expansion is quoted so that a value containing whitespace or glob
# characters is still passed to run_glue.py as exactly one argument.
python src/glue/run_glue.py \
  --model_type "$config_task2_model_type" \
  --model_name_or_path "$config_task2_model_path" \
  --task_name "$config_task2_task_name" \
  --do_train \
  --do_eval \
  --do_lower_case \
  --data_dir "${config_task2_glue_dir}/${config_task2_task_name}/" \
  --max_seq_length "$config_task2_max_seq_length" \
  --per_gpu_train_batch_size "$config_task2_per_gpu_train_batch_size" \
  --learning_rate "$config_task2_learning_rate" \
  --num_train_epochs "$config_task2_num_train_epochs" \
  --output_dir "${output_dir}/${config_task2_task_name}"

# Train and evaluate on the task described by the config_task3_* variables.
# Every expansion is quoted so that a value containing whitespace or glob
# characters is still passed to run_glue.py as exactly one argument.
python src/glue/run_glue.py \
  --model_type "$config_task3_model_type" \
  --model_name_or_path "$config_task3_model_path" \
  --task_name "$config_task3_task_name" \
  --do_train \
  --do_eval \
  --do_lower_case \
  --data_dir "${config_task3_glue_dir}/${config_task3_task_name}/" \
  --max_seq_length "$config_task3_max_seq_length" \
  --per_gpu_train_batch_size "$config_task3_per_gpu_train_batch_size" \
  --learning_rate "$config_task3_learning_rate" \
  --num_train_epochs "$config_task3_num_train_epochs" \
  --output_dir "${output_dir}/${config_task3_task_name}"

27 changes: 27 additions & 0 deletions src/config/glue_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# One top-level entry per GLUE run. The shell YAML parser flattens each nested
# key to <prefix><task>_<key> and derives the nesting level from two-space
# indentation, so the settings of a task must be indented by exactly two spaces.
task1:
  model_type: "bert"
  model_path: "bert-base-uncased"
  task_name: "RTE"
  glue_dir: "data/glue_data"
  max_seq_length: 128
  per_gpu_train_batch_size: 32
  learning_rate: 2e-5
  num_train_epochs: 3.0
task2:
  model_type: "bert"
  model_path: "bert-base-uncased"
  task_name: "QQP"
  glue_dir: "data/glue_data"
  max_seq_length: 128
  per_gpu_train_batch_size: 32
  learning_rate: 2e-5
  num_train_epochs: 3.0
task3:
  model_type: "bert"
  model_path: "bert-base-uncased"
  task_name: "CoLA"
  glue_dir: "data/glue_data"
  max_seq_length: 128
  per_gpu_train_batch_size: 32
  learning_rate: 2e-5
  num_train_epochs: 3.0
16 changes: 16 additions & 0 deletions yaml_parser.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/bin/sh
#######################################
# Flatten a simple, two-space-indented YAML file into shell assignments.
# Each "key: value" line becomes <prefix><parent>_<key>="<value>", where the
# parents are the keys of the enclosing, less-indented levels. Lines that carry
# no value (section headers such as "task1:") only set the parent name.
# Arguments:
#   $1 - path to the YAML file; when empty, the YAML is read from stdin
#   $2 - prefix prepended to every generated variable name
# Outputs:
#   one assignment per line on stdout, suitable for eval "$(parse_yaml ...)"
# Returns:
#   1 if a path was given but the file cannot be read, otherwise the status of
#   the sed | awk pipeline
#######################################
parse_yaml() {
  local yaml_file=$1
  local prefix=$2
  local s='[[:space:]]*' w='[a-zA-Z0-9_]*' fs
  # ASCII FS (octal 034): a separator that cannot occur in ordinary YAML text.
  fs=$(printf '\034')

  if [ -n "$yaml_file" ] && [ ! -r "$yaml_file" ]; then
    printf 'parse_yaml: cannot read %s\n' "$yaml_file" >&2
    return 1
  fi

  # sed emits "<indent><FS><key><FS><value>": the first expression handles
  # double-quoted values (quotes stripped), the second handles bare values.
  # ${yaml_file:+...} passes the path as exactly one word, or no word at all
  # when it is empty so that sed falls back to reading stdin.
  sed -ne "s|^\($s\)\($w\)$s:$s\"\(.*\)\"$s\$|\1$fs\2$fs\3|p" \
    -e "s|^\($s\)\($w\)$s:$s\(.*\)$s\$|\1$fs\2$fs\3|p" \
    -- ${yaml_file:+"$yaml_file"} |
    awk -F"$fs" -v prefix="$prefix" '{
      # Nesting depth: two spaces of indentation per level.
      indent = length($1)/2;
      vname[indent] = $2;
      # Forget keys remembered from deeper levels of a previous section.
      for (i in vname) {if (i > indent) {delete vname[i]}}
      if (length($3) > 0) {
        vn=""; for (i=0; i<indent; i++) {vn=(vn)(vname[i])("_")}
        printf("%s%s%s=\"%s\"\n", prefix, vn, $2, $3);
      }
    }'
}

0 comments on commit b5387f7

Please sign in to comment.