From 865cff3905c949d874d2a8678426c112138dad82 Mon Sep 17 00:00:00 2001
From: Ben Epstein
Date: Thu, 21 Sep 2023 18:01:09 -0400
Subject: [PATCH 1/2] align names to rage2e params

---
 README.md                                            |  4 ++--
 dalm/cli.py                                          |  4 ++--
 dalm/eval/README.md                                  |  4 ++--
 dalm/eval/eval_rag.py                                |  6 +++---
 dalm/eval/eval_retriever_only.py                     |  4 ++--
 dalm/training/retriever_only/train_retriever_only.py | 10 +++++-----
 experiments/llama-index-10k/README.md                |  4 ++--
 7 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/README.md b/README.md
index 8ac3a17..07020f5 100644
--- a/README.md
+++ b/README.md
@@ -123,13 +123,13 @@
 To run retriever only eval (make sure you have the checkpoints in the project root)
 
 ```bash
- python dalm/eval/eval_retriever_only.py --dataset_path qa_pairs_test.csv --retriever_model_name_or_path "BAAI/bge-large-en" --passage_column_name Abstract --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints
+ python dalm/eval/eval_retriever_only.py --dataset_path qa_pairs_test.csv --retriever_name_or_path "BAAI/bge-large-en" --passage_column_name Abstract --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints
 ```
 
 For the e2e eval
 
 ```bash
-python dalm/eval/eval_rag.py --dataset_path qa_pairs_test_2.csv --retriever_model_name_or_path "BAAI/bge-large-en" --generator_model_name_or_path "meta-llama/Llama-2-7b-hf" --passage_column_name Abstract --query_column_name Question --answer_column_name Answer --evaluate_generator --query_batch_size 5 --retriever_peft_model_path rag_e2e_checkpoints/retriever --generator_peft_model_path rag_e2e_checkpoints/generator
+python dalm/eval/eval_rag.py --dataset_path qa_pairs_test_2.csv --retriever_name_or_path "BAAI/bge-large-en" --generator_name_or_path "meta-llama/Llama-2-7b-hf" --passage_column_name Abstract --query_column_name Question --answer_column_name Answer --evaluate_generator --query_batch_size 5 --retriever_peft_model_path rag_e2e_checkpoints/retriever --generator_peft_model_path rag_e2e_checkpoints/generator
 ```
diff --git a/dalm/cli.py b/dalm/cli.py
index 2a3b614..6b5ceae 100644
--- a/dalm/cli.py
+++ b/dalm/cli.py
@@ -159,7 +159,7 @@ def train_rag_e2e(
 
 @cli.command()
 def train_retriever_only(
-    model_name_or_path: Annotated[
+    retriever_name_or_path: Annotated[
         str, typer.Argument(help="Path to the model or identifier from huggingface.co/models.", show_default=False)
     ],
     dataset_path: Annotated[
@@ -238,7 +238,7 @@ def train_retriever_only(
     """End-to-end train an in-domain model, including the retriever and generator"""
     train_retriever(
         dataset_or_path=dataset_path,
-        model_name_or_path=model_name_or_path,
+        retriever_name_or_path=retriever_name_or_path,
         dataset_passage_col_name=dataset_passage_col_name,
         dataset_query_col_name=dataset_query_col_name,
         query_max_len=query_max_len,
diff --git a/dalm/eval/README.md b/dalm/eval/README.md
index 336e8a9..0f575ad 100644
--- a/dalm/eval/README.md
+++ b/dalm/eval/README.md
@@ -4,11 +4,11 @@
 To run retriever only eval (make sure you have the checkpoints in the project root)
 
 ```bash
- python dalm/eval/eval_retriever_only.py --dataset_path qa_paits_test.csv --retriever_model_name_or_path "BAAI/bge-large-en" --passage_column_name Abstract --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints
+ python dalm/eval/eval_retriever_only.py --dataset_path qa_pairs_test.csv --retriever_name_or_path "BAAI/bge-large-en" --passage_column_name Abstract --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints
 ```
 
 For the e2e eval
 
 ```bash
-python dalm/eval/eval_rag.py --dataset_path qa_pairs_test_2.csv --retriever_model_name_or_path "BAAI/bge-large-en" --generator_model_name_or_path "meta-llama/Llama-2-7b-hf" --passage_column_name Abstract --query_column_name Question --answer_column_name Answer --evaluate_generator --query_batch_size 5 --retriever_peft_model_path retriever_only_checkpoints --generator_peft_model_path generator_only_checkpoints
+python dalm/eval/eval_rag.py --dataset_path qa_pairs_test_2.csv --retriever_name_or_path "BAAI/bge-large-en" --generator_name_or_path "meta-llama/Llama-2-7b-hf" --passage_column_name Abstract --query_column_name Question --answer_column_name Answer --evaluate_generator --query_batch_size 5 --retriever_peft_model_path retriever_only_checkpoints --generator_peft_model_path generator_only_checkpoints
 ```
diff --git a/dalm/eval/eval_rag.py b/dalm/eval/eval_rag.py
index b4ec53f..895161c 100644
--- a/dalm/eval/eval_rag.py
+++ b/dalm/eval/eval_rag.py
@@ -55,13 +55,13 @@ def parse_args() -> Namespace:
         ),
     )
     parser.add_argument(
-        "--retriever_model_name_or_path",
+        "--retriever_name_or_path",
         type=str,
         help="Path to pretrained retriever model or model identifier from huggingface.co/models.",
         required=True,
     )
     parser.add_argument(
-        "--generator_model_name_or_path",
+        "--generator_name_or_path",
         type=str,
         help="Path to pretrained generator model or model identifier from huggingface.co/models.",
         required=True,
@@ -141,7 +141,7 @@ def main() -> None:
 
     # rag retriver and the generator (don't load new peft layers no need)
     rag_model = AutoModelForRagE2E(
-        args.retriever_model_name_or_path, args.generator_model_name_or_path, get_peft=False, use_bnb=False
+        args.retriever_name_or_path, args.generator_name_or_path, get_peft=False, use_bnb=False
     )
 
     # load the test dataset
diff --git a/dalm/eval/eval_retriever_only.py b/dalm/eval/eval_retriever_only.py
index 13b52e8..d100114 100644
--- a/dalm/eval/eval_retriever_only.py
+++ b/dalm/eval/eval_retriever_only.py
@@ -54,7 +54,7 @@ def parse_args() -> Namespace:
         ),
     )
     parser.add_argument(
-        "--retriever_model_name_or_path",
+        "--retriever_name_or_path",
         type=str,
         help="Path to pretrained retriever model or model identifier from huggingface.co/models.",
         required=True,
@@ -104,7 +104,7 @@ def main() -> None:
     SELECTED_TORCH_DTYPE: Final[torch.dtype] = torch.float16 if args.torch_dtype == "float16" else torch.bfloat16
 
     # rag retriver and the generator (don't load new peft layers no need)
-    retriever_model = AutoModelForSentenceEmbedding(args.retriever_model_name_or_path, get_peft=False, use_bnb=False)
+    retriever_model = AutoModelForSentenceEmbedding(args.retriever_name_or_path, get_peft=False, use_bnb=False)
 
     # load the test dataset
     test_dataset = (
diff --git a/dalm/training/retriever_only/train_retriever_only.py b/dalm/training/retriever_only/train_retriever_only.py
index 1073cf5..fafca84 100644
--- a/dalm/training/retriever_only/train_retriever_only.py
+++ b/dalm/training/retriever_only/train_retriever_only.py
@@ -67,7 +67,7 @@ def parse_args() -> Namespace:
         ),
     )
     parser.add_argument(
-        "--model_name_or_path",
+        "--retriever_name_or_path",
         type=str,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
         required=True,
@@ -163,7 +163,7 @@ def parse_args() -> Namespace:
 
 
 def train_retriever(
-    model_name_or_path: str,
+    retriever_name_or_path: str,
     dataset_or_path: str | Dataset,
     dataset_passage_col_name: str = "Abstract",
     dataset_query_col_name: str = "Question",
@@ -220,7 +220,7 @@ def train_retriever(
     os.makedirs(output_dir, exist_ok=True)
     accelerator.wait_for_everyone()
 
-    model = AutoModelForSentenceEmbedding(model_name_or_path, use_bnb=True, get_peft=use_peft)
+    model = AutoModelForSentenceEmbedding(retriever_name_or_path, use_bnb=True, get_peft=use_peft)
     tokenizer = model.tokenizer
 
     # dataset download and preprocessing
@@ -417,7 +417,7 @@ def main() -> None:
     args = parse_args()
     train_retriever(
         dataset_or_path=args.dataset_path,
-        model_name_or_path=args.model_name_or_path,
+        retriever_name_or_path=args.retriever_name_or_path,
         dataset_passage_col_name=args.dataset_passage_col_name,
         dataset_query_col_name=args.dataset_query_col_name,
         query_max_len=args.query_max_len,
@@ -449,5 +449,5 @@ def main() -> None:
 
 
 # python contrastive_train/peft_lora_constrastive_learning.py --dataset_path "xxxx.csv" \
-# --model_name_or_path "BAAI/bge-small-en" --output_dir "./retriever_only_checkpoints" --use_peft \
+# --retriever_name_or_path "BAAI/bge-small-en" --output_dir "./retriever_only_checkpoints" --use_peft \
 # --with_tracking --report_to all --per_device_train_batch_size 30
diff --git a/experiments/llama-index-10k/README.md b/experiments/llama-index-10k/README.md
index d4dfc75..85141f2 100644
--- a/experiments/llama-index-10k/README.md
+++ b/experiments/llama-index-10k/README.md
@@ -56,7 +56,7 @@ dalm train-rag-e2e \
 
 And eval
 ```
-python ../../dalm/eval/eval_retriever_only.py --dataset_path qa-outputs/question_answer_pairs_test.csv --retriever_model_name_or_path "BAAI/bge-small-en" --passage_column_name text --query_column_name Question --retriever_peft_model_path rag_e2e_checkpoints_bgsmall/retriever --embed_dim 384
+python ../../dalm/eval/eval_retriever_only.py --dataset_path qa-outputs/question_answer_pairs_test.csv --retriever_name_or_path "BAAI/bge-small-en" --passage_column_name text --query_column_name Question --retriever_peft_model_path rag_e2e_checkpoints_bgsmall/retriever --embed_dim 384
 
 *************
 Retriever results:
@@ -80,7 +80,7 @@ dalm train-retriever-only "BAAI/bge-small-en" "qa-outputs/question_answer_pairs.
 
 and eval
 ```
-python ../../dalm/eval/eval_retriever_only.py --dataset_path qa-outputs/question_answer_pairs_test.csv --retriever_model_name_or_path "BAAI/bge-small-en" --passage_column_name text --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints_bgsmall/ --embed_dim 384
+python ../../dalm/eval/eval_retriever_only.py --dataset_path qa-outputs/question_answer_pairs_test.csv --retriever_name_or_path "BAAI/bge-small-en" --passage_column_name text --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints_bgsmall/ --embed_dim 384
 
 *************
 Retriever results:

From ec6c62781e11fbdb635349dc65df2437f8199057 Mon Sep 17 00:00:00 2001
From: Ben Epstein
Date: Thu, 21 Sep 2023 18:35:54 -0400
Subject: [PATCH 2/2] align param names

---
 dalm/__init__.py                             |  2 +-
 dalm/cli.py                                  | 24 ++++++++-----------
 dalm/eval/eval_rag.py                        |  4 ++--
 dalm/eval/eval_retriever_only.py             |  4 ++--
 dalm/eval/utils.py                           |  8 +++----
 dalm/training/rag_e2e/train_rage2e.py        | 24 +++++++++----------
 .../retriever_only/train_retriever_only.py   | 16 ++++++-------
 .../utils/rag_e2e_dataloader_utils.py        | 12 +++++-----
 .../utils/retriever_only_dataloader_utils.py |  8 +++----
 experiments/llama-index-10k/README.md        |  4 ++--
 10 files changed, 51 insertions(+), 55 deletions(-)

diff --git a/dalm/__init__.py b/dalm/__init__.py
index 27fdca4..81f0fde 100644
--- a/dalm/__init__.py
+++ b/dalm/__init__.py
@@ -1 +1 @@
-__version__ = "0.0.3"
+__version__ = "0.0.4"
diff --git a/dalm/cli.py b/dalm/cli.py
index 6b5ceae..fcda96b 100644
--- a/dalm/cli.py
+++ b/dalm/cli.py
@@ -51,11 +51,9 @@ def train_rag_e2e(
             help="Path to pretrained (causal) generator or identifier from huggingface.co/models.", show_default=False
         ),
     ],
-    dataset_passage_col_name: Annotated[
-        str, typer.Option(help="Name of the column containing the passage")
-    ] = "Abstract",
-    dataset_query_col_name: Annotated[str, typer.Option(help="Name of the column containing the query")] = "Question",
-    dataset_answer_col_name: Annotated[str, typer.Option(help="Name of the column containing the Answer")] = "Answer",
+    passage_column_name: Annotated[str, typer.Option(help="Name of the column containing the passage")] = "Abstract",
+    query_column_name: Annotated[str, typer.Option(help="Name of the column containing the query")] = "Question",
+    answer_column_name: Annotated[str, typer.Option(help="Name of the column containing the Answer")] = "Answer",
     query_max_len: Annotated[
         int, typer.Option(help="The max query sequence length during tokenization. Longer sequences are truncated")
     ] = 50,
@@ -129,9 +127,9 @@ def train_rag_e2e(
         dataset_or_path=dataset_path,
         retriever_name_or_path=retriever_name_or_path,
         generator_name_or_path=generator_name_or_path,
-        dataset_passage_col_name=dataset_passage_col_name,
-        dataset_query_col_name=dataset_query_col_name,
-        dataset_answer_col_name=dataset_answer_col_name,
+        passage_column_name=passage_column_name,
+        query_column_name=query_column_name,
+        answer_column_name=answer_column_name,
         query_max_len=query_max_len,
         passage_max_len=passage_max_len,
         generator_max_len=generator_max_len,
@@ -169,10 +167,8 @@ def train_retriever_only(
             show_default=False,
         ),
     ],
-    dataset_passage_col_name: Annotated[
-        str, typer.Option(help="Name of the column containing the passage")
-    ] = "Abstract",
-    dataset_query_col_name: Annotated[str, typer.Option(help="Name of the column containing the query")] = "Question",
+    passage_column_name: Annotated[str, typer.Option(help="Name of the column containing the passage")] = "Abstract",
+    query_column_name: Annotated[str, typer.Option(help="Name of the column containing the query")] = "Question",
     query_max_len: Annotated[
         int, typer.Option(help="The max query sequence length during tokenization. Longer sequences are truncated")
     ] = 50,
@@ -239,8 +235,8 @@ def train_retriever_only(
     train_retriever(
         dataset_or_path=dataset_path,
         retriever_name_or_path=retriever_name_or_path,
-        dataset_passage_col_name=dataset_passage_col_name,
-        dataset_query_col_name=dataset_query_col_name,
+        passage_column_name=passage_column_name,
+        query_column_name=query_column_name,
         query_max_len=query_max_len,
         passage_max_len=passage_max_len,
         per_device_train_batch_size=per_device_train_batch_size,
diff --git a/dalm/eval/eval_rag.py b/dalm/eval/eval_rag.py
index 895161c..c6065a3 100644
--- a/dalm/eval/eval_rag.py
+++ b/dalm/eval/eval_rag.py
@@ -162,8 +162,8 @@ def main() -> None:
         lambda example: preprocess_function(
             example,
             retriever_tokenizer,
-            query_col_name=args.query_column_name,
-            passage_col_name=args.passage_column_name,
+            query_column_name=args.query_column_name,
+            passage_column_name=args.passage_column_name,
         ),
         batched=True,
         # remove_columns=test_dataset.column_names,
diff --git a/dalm/eval/eval_retriever_only.py b/dalm/eval/eval_retriever_only.py
index d100114..c80093f 100644
--- a/dalm/eval/eval_retriever_only.py
+++ b/dalm/eval/eval_retriever_only.py
@@ -121,8 +121,8 @@ def main() -> None:
         lambda example: preprocess_function(
             example,
             retriever_tokenizer,
-            query_col_name=args.query_column_name,
-            passage_col_name=args.passage_column_name,
+            query_column_name=args.query_column_name,
+            passage_column_name=args.passage_column_name,
         ),
         batched=True,
         # remove_columns=test_dataset.column_names,
diff --git a/dalm/eval/utils.py b/dalm/eval/utils.py
index 4d2d1f6..d24a378 100644
--- a/dalm/eval/utils.py
+++ b/dalm/eval/utils.py
@@ -78,11 +78,11 @@ def calculate_precision_recall(retrieved_items: List, correct_items: List) -> Tu
 def preprocess_function(
     examples: LazyBatch,
     retriever_tokenizer: PreTrainedTokenizer,
-    query_col_name: str = "query",
-    passage_col_name: str = "passage",
+    query_column_name: str = "query",
+    passage_column_name: str = "passage",
 ) -> Dict[str, torch.Tensor]:
-    queries = examples[query_col_name]
-    passages = examples[passage_col_name]
+    queries = examples[query_column_name]
+    passages = examples[passage_column_name]
 
     # Tokenization for the retriever
     retriever_query_tokens = retriever_tokenizer(queries, padding="max_length", max_length=128, truncation=True)
diff --git a/dalm/training/rag_e2e/train_rage2e.py b/dalm/training/rag_e2e/train_rage2e.py
index 160c562..264b4da 100644
--- a/dalm/training/rag_e2e/train_rage2e.py
+++ b/dalm/training/rag_e2e/train_rage2e.py
@@ -60,13 +60,13 @@ def parse_args() -> Namespace:
         help=("Dataset path. Can be a huggingface dataset directory or a csv file."),
     )
     parser.add_argument(
-        "--dataset_passage_col_name", type=str, default="Abstract", help="Name of the column containing the passage"
+        "--passage_column_name", type=str, default="Abstract", help="Name of the column containing the passage"
     )
     parser.add_argument(
-        "--dataset_query_col_name", type=str, default="Question", help="Name of the column containing the query"
+        "--query_column_name", type=str, default="Question", help="Name of the column containing the query"
     )
     parser.add_argument(
-        "--dataset_answer_col_name", type=str, default="Answer", help="Name of the column containing the answer"
+        "--answer_column_name", type=str, default="Answer", help="Name of the column containing the answer"
     )
     parser.add_argument(
         "--query_max_len",
@@ -217,9 +217,9 @@ def train_e2e(
     dataset_or_path: str | Dataset,
     retriever_name_or_path: str,
     generator_name_or_path: str,
-    dataset_passage_col_name: str = "Abstract",
-    dataset_query_col_name: str = "Question",
-    dataset_answer_col_name: str = "Answer",
+    passage_column_name: str = "Abstract",
+    query_column_name: str = "Question",
+    answer_column_name: str = "Answer",
     query_max_len: int = 50,
     passage_max_len: int = 128,
     generator_max_len: int = 256,
@@ -295,9 +295,9 @@ def train_e2e(
             example,
             retriever_tokenizer=rag_model.retriever_tokenizer,
             generator_tokenizer=rag_model.generator_tokenizer,
-            dataset_query_col_name=dataset_query_col_name,
-            dataset_passage_col_name=dataset_passage_col_name,
-            dataset_answer_col_name=dataset_answer_col_name,
+            query_column_name=query_column_name,
+            passage_column_name=passage_column_name,
+            answer_column_name=answer_column_name,
             query_max_len=query_max_len,
             passage_max_len=passage_max_len,
             generator_max_len=generator_max_len,
@@ -523,9 +523,9 @@ def main() -> None:
         dataset_or_path=args.dataset_path,
         retriever_name_or_path=args.retriever_name_or_path,
         generator_name_or_path=args.generator_name_or_path,
-        dataset_passage_col_name=args.dataset_passage_col_name,
-        dataset_query_col_name=args.dataset_query_col_name,
-        dataset_answer_col_name=args.dataset_answer_col_name,
+        passage_column_name=args.passage_column_name,
+        query_column_name=args.query_column_name,
+        answer_column_name=args.answer_column_name,
         query_max_len=args.query_max_len,
         passage_max_len=args.passage_max_len,
         generator_max_len=args.generator_max_len,
diff --git a/dalm/training/retriever_only/train_retriever_only.py b/dalm/training/retriever_only/train_retriever_only.py
index fafca84..ad78261 100644
--- a/dalm/training/retriever_only/train_retriever_only.py
+++ b/dalm/training/retriever_only/train_retriever_only.py
@@ -44,10 +44,10 @@ def parse_args() -> Namespace:
     parser = argparse.ArgumentParser(description="training a PEFT model for Sematic Search task")
     parser.add_argument("--dataset_path", type=str, default=None, help="dataset path in the local dir")
     parser.add_argument(
-        "--dataset_query_col_name", type=str, default="Question", help="Name of the query column in the dataset"
+        "--query_column_name", type=str, default="Question", help="Name of the query column in the dataset"
     )
     parser.add_argument(
-        "--dataset_passage_col_name", type=str, default="Abstract", help="Name of the passage column in the dataset"
+        "--passage_column_name", type=str, default="Abstract", help="Name of the passage column in the dataset"
     )
     parser.add_argument(
         "--query_max_len",
@@ -165,8 +165,8 @@ def parse_args() -> Namespace:
 def train_retriever(
     retriever_name_or_path: str,
     dataset_or_path: str | Dataset,
-    dataset_passage_col_name: str = "Abstract",
-    dataset_query_col_name: str = "Question",
+    passage_column_name: str = "Abstract",
+    query_column_name: str = "Question",
     query_max_len: int = 50,
     passage_max_len: int = 128,
     per_device_train_batch_size: int = 32,
@@ -236,8 +236,8 @@ def train_retriever(
         lambda example: preprocess_dataset(
             example,
             tokenizer,
-            query_col_name=dataset_query_col_name,
-            passage_col_name=dataset_passage_col_name,
+            query_column_name=query_column_name,
+            passage_column_name=passage_column_name,
             query_max_len=query_max_len,
             passage_max_len=passage_max_len,
         ),
@@ -418,8 +418,8 @@ def main() -> None:
     train_retriever(
         dataset_or_path=args.dataset_path,
         retriever_name_or_path=args.retriever_name_or_path,
-        dataset_passage_col_name=args.dataset_passage_col_name,
-        dataset_query_col_name=args.dataset_query_col_name,
+        passage_column_name=args.passage_column_name,
+        query_column_name=args.query_column_name,
         query_max_len=args.query_max_len,
         passage_max_len=args.passage_max_len,
         per_device_train_batch_size=args.per_device_train_batch_size,
diff --git a/dalm/training/utils/rag_e2e_dataloader_utils.py b/dalm/training/utils/rag_e2e_dataloader_utils.py
index bf58011..6f62ad4 100644
--- a/dalm/training/utils/rag_e2e_dataloader_utils.py
+++ b/dalm/training/utils/rag_e2e_dataloader_utils.py
@@ -8,16 +8,16 @@ def preprocess_dataset(
     examples: LazyBatch,
     retriever_tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
     generator_tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
-    dataset_query_col_name: str,
-    dataset_passage_col_name: str,
-    dataset_answer_col_name: str,
+    query_column_name: str,
+    passage_column_name: str,
+    answer_column_name: str,
     query_max_len: int,
     passage_max_len: int,
     generator_max_len: int,
 ) -> Dict[str, Any]:
-    querie_list = examples[dataset_query_col_name]
-    passage_list = examples[dataset_passage_col_name]
-    answers = examples[dataset_answer_col_name]
+    querie_list = examples[query_column_name]
+    passage_list = examples[passage_column_name]
+    answers = examples[answer_column_name]
 
     queries = [f"#query# {query}" for query in querie_list]
     passages = [f"#passage# {passage}" for passage in passage_list]
diff --git a/dalm/training/utils/retriever_only_dataloader_utils.py b/dalm/training/utils/retriever_only_dataloader_utils.py
index 9e2da0d..9f21601 100644
--- a/dalm/training/utils/retriever_only_dataloader_utils.py
+++ b/dalm/training/utils/retriever_only_dataloader_utils.py
@@ -8,17 +8,17 @@ def preprocess_dataset(
     examples: LazyBatch,
     tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
-    query_col_name: str,
-    passage_col_name: str,
+    query_column_name: str,
+    passage_column_name: str,
     query_max_len: int,
     passage_max_len: int,
 ) -> Dict[str, torch.Tensor]:
-    query_list = examples[query_col_name]
+    query_list = examples[query_column_name]
     queries = [f"#query# {query}" for query in query_list]
     result_ = tokenizer(queries, padding="max_length", max_length=query_max_len, truncation=True)
     result_ = {f"query_{k}": v for k, v in result_.items()}
 
-    passage_list = examples[passage_col_name]
+    passage_list = examples[passage_column_name]
     passages = [f"#passage# {passage}" for passage in passage_list]
     result_passage = tokenizer(passages, padding="max_length", max_length=passage_max_len, truncation=True)
 
     for k, v in result_passage.items():
diff --git a/experiments/llama-index-10k/README.md b/experiments/llama-index-10k/README.md
index 85141f2..a33c019 100644
--- a/experiments/llama-index-10k/README.md
+++ b/experiments/llama-index-10k/README.md
@@ -48,7 +48,7 @@ dalm train-rag-e2e \
 "qa-outputs/question_answer_pairs.csv" \
 "BAAI/bge-small-en" \
 "meta-llama/Llama-2-7b-hf" \
---dataset-passage-col-name text \
+--passage-column-name text \
 --output-dir "rag_e2e_checkpoints_bgsmall" \
 --no-with-tracking \
 --per-device-train-batch-size 12
@@ -74,7 +74,7 @@ Train the retriever only
 dalm train-retriever-only "BAAI/bge-small-en" "qa-outputs/question_answer_pairs.csv" \
 --output-dir "retriever_only_checkpoints_bgsmall" \
 --use-peft \
---dataset-passage-col-name text \
+--passage-column-name text \
 --per-device-train-batch-size 150
 ```
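
For reference, a minimal usage sketch of the Python API after both patches. The keyword names come from the signatures in the diffs above; the model name, dataset path, and column values are illustrative placeholders, not part of the patch:

```python
# Minimal sketch: calling the retriever-only trainer with the renamed
# parameters. Values are placeholders; only keywords that appear in the
# patched signature of train_retriever are used.
from dalm.training.retriever_only.train_retriever_only import train_retriever

train_retriever(
    retriever_name_or_path="BAAI/bge-small-en",  # renamed from model_name_or_path in PATCH 1/2
    dataset_or_path="qa-outputs/question_answer_pairs.csv",
    passage_column_name="text",    # renamed from dataset_passage_col_name in PATCH 2/2
    query_column_name="Question",  # renamed from dataset_query_col_name in PATCH 2/2
    query_max_len=50,
    passage_max_len=128,
    per_device_train_batch_size=150,
)
```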